This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c897e4c [Feature] Support build and build_and_load. (#213)
c897e4c is described below
commit c897e4c9c2ed6f86cf5ef470a78e453eb040fb60
Author: DarkSharpness <[email protected]>
AuthorDate: Wed Nov 5 09:47:30 2025 +0800
[Feature] Support build and build_and_load. (#213)
Related issue #184. This PR improves the functionality of
`build_inline`. We implement new `build`/`load` functions
which accept paths to cpp/cuda files instead of raw strings of cpp/cuda
source. They have an interface similar to
[`torch.utils.cpp_extension.load`](https://docs.pytorch.org/docs/2.9/cpp_extension.html#torch.utils.cpp_extension.load).
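A minimal usage sketch of the new API (the file name "my_ops.cpp" is
illustrative; see the docstrings in the diff below for details):

    import tvm_ffi
    import tvm_ffi.cpp

    # compile source files into a shared library and get its path
    lib_path = tvm_ffi.cpp.build(name="my_ops", cpp_files="my_ops.cpp")
    mod = tvm_ffi.load_module(lib_path)

    # or compile and load in one step
    mod = tvm_ffi.cpp.load(name="my_ops", cpp_files="my_ops.cpp")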
---
docs/reference/python/index.rst | 5 +-
python/tvm_ffi/cpp/__init__.py | 9 +-
.../tvm_ffi/cpp/{load_inline.py => extension.py} | 485 ++++++++++++++++++---
.../utils/_build_optional_torch_c_dlpack.py | 2 +-
tests/python/test_build.cc | 39 ++
tests/python/test_build.h | 25 ++
.../cpp/__init__.py => tests/python/test_build.py | 26 +-
7 files changed, 521 insertions(+), 70 deletions(-)
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 3a5541e..d36f025 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -102,7 +102,7 @@ Stream Context
get_raw_stream
-Inline Loading
+C++ Extension
--------------
C++ integration helpers for building and loading inline modules.
@@ -112,7 +112,8 @@ C++ integration helpers for building and loading inline modules.
cpp.load_inline
cpp.build_inline
-
+ cpp.load
+ cpp.build
Misc
----
diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py
index 8835e4a..e3deb12 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/python/tvm_ffi/cpp/__init__.py
@@ -16,4 +16,11 @@
# under the License.
"""C++ integration helpers for building and loading inline modules."""
-from .load_inline import build_inline, load_inline
+from .extension import build, build_inline, load, load_inline
+
+__all__ = [
+ "build",
+ "build_inline",
+ "load",
+ "load_inline",
+]
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/extension.py
similarity index 61%
rename from python/tvm_ffi/cpp/load_inline.py
rename to python/tvm_ffi/cpp/extension.py
index 50cc7f0..7d84476 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/extension.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Build and load inline C++/CUDA sources into a tvm_ffi Module using Ninja."""
+"""Build and load C++/CUDA sources into a tvm_ffi Module using Ninja."""
from __future__ import annotations
@@ -25,6 +25,7 @@ import shutil
import subprocess
import sys
from collections.abc import Mapping, Sequence
+from contextlib import nullcontext
from pathlib import Path
from tvm_ffi.libinfo import find_dlpack_include_path, find_include_path, find_libtvm_ffi
@@ -35,8 +36,10 @@ IS_WINDOWS = sys.platform == "win32"
def _hash_sources(
- cpp_source: str,
- cuda_source: str,
+ cpp_source: str | None,
+ cuda_source: str | None,
+ cpp_files: Sequence[str] | None,
+ cuda_files: Sequence[str] | None,
functions: Sequence[str] | Mapping[str, str],
extra_cflags: Sequence[str],
extra_cuda_cflags: Sequence[str],
@@ -45,23 +48,32 @@ def _hash_sources(
) -> str:
"""Generate a unique hash for the given sources and functions."""
m = hashlib.sha256()
- m.update(cpp_source.encode("utf-8"))
- m.update(cuda_source.encode("utf-8"))
+
+ def _maybe_hash_string(source: str | None) -> None:
+ if source is not None:
+ m.update(source.encode("utf-8"))
+
+ def _hash_sequence(seq: Sequence[str]) -> None:
+ for item in seq:
+ m.update(item.encode("utf-8"))
+
+ def _hash_mapping(mapping: Mapping[str, str]) -> None:
+ for key in sorted(mapping):
+ m.update(key.encode("utf-8"))
+ m.update(mapping[key].encode("utf-8"))
+
+ _maybe_hash_string(cpp_source)
+ _maybe_hash_string(cuda_source)
+ _hash_sequence(sorted(cpp_files or []))
+ _hash_sequence(sorted(cuda_files or []))
if isinstance(functions, Mapping):
- for name in sorted(functions):
- m.update(name.encode("utf-8"))
- m.update(functions[name].encode("utf-8"))
+ _hash_mapping(functions)
else:
- for name in sorted(functions):
- m.update(name.encode("utf-8"))
- for flag in extra_cflags:
- m.update(flag.encode("utf-8"))
- for flag in extra_cuda_cflags:
- m.update(flag.encode("utf-8"))
- for flag in extra_ldflags:
- m.update(flag.encode("utf-8"))
- for path in extra_include_paths:
- m.update(path.encode("utf-8"))
+ _hash_sequence(sorted(functions))
+ _hash_sequence(extra_cflags)
+ _hash_sequence(extra_cuda_cflags)
+ _hash_sequence(extra_ldflags)
+ _hash_sequence(extra_include_paths)
return m.hexdigest()[:16]
@@ -197,12 +209,13 @@ def _run_command_in_dev_prompt(
def _generate_ninja_build( # noqa: PLR0915
name: str,
- build_dir: str,
with_cuda: bool,
extra_cflags: Sequence[str],
extra_cuda_cflags: Sequence[str],
extra_ldflags: Sequence[str],
extra_include_paths: Sequence[str],
+ cpp_files: Sequence[str],
+ cuda_files: Sequence[str],
) -> str:
"""Generate the content of build.ninja for building the module."""
default_include_paths = [find_include_path(), find_dlpack_include_path()]
@@ -258,7 +271,7 @@ def _generate_ninja_build( # noqa: PLR0915
cuda_cflags.append("-I{}".format(path.replace(":", "$:")))
# flags
- ninja = []
+ ninja: list[str] = []
ninja.append("ninja_required_version = 1.3")
ninja.append("cxx = {}".format(os.environ.get("CXX", "cl" if IS_WINDOWS
else "c++")))
ninja.append("cflags = {}".format(" ".join(cflags)))
@@ -296,20 +309,21 @@ def _generate_ninja_build( # noqa: PLR0915
ninja.append("")
# build targets
- ninja.append(
- "build main.o: compile {}".format(
- str((Path(build_dir) / "main.cpp").resolve()).replace(":", "$:")
- )
- )
- if with_cuda:
- ninja.append(
- "build cuda.o: compile_cuda {}".format(
- str((Path(build_dir) / "cuda.cu").resolve()).replace(":", "$:")
- )
- )
+ link_files: list[str] = []
+ for i, cpp_path in enumerate(sorted(cpp_files)):
+ obj_name = f"cpp_{i}.o"
+ ninja.append("build {}: compile {}".format(obj_name,
cpp_path.replace(":", "$:")))
+ link_files.append(obj_name)
+
+ for i, cuda_path in enumerate(sorted(cuda_files)):
+ obj_name = f"cuda_{i}.o"
+ ninja.append("build {}: compile_cuda {}".format(obj_name,
cuda_path.replace(":", "$:")))
+ link_files.append(obj_name)
+
# Use appropriate extension based on platform
ext = ".dll" if IS_WINDOWS else ".so"
- ninja.append("build {}{}: link main.o{}".format(name, ext, " cuda.o" if
with_cuda else ""))
+ link_name = " ".join(link_files)
+ ninja.append(f"build {name}{ext}: link {link_name}")
ninja.append("")
# default target
@@ -360,7 +374,82 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
return "\n".join(sources)
-def build_inline( # noqa: PLR0915, PLR0912
+def _str_seq2list(seq: Sequence[str] | str | None) -> list[str]:
+ if seq is None:
+ return []
+ elif isinstance(seq, str):
+ return [seq]
+ else:
+ return list(seq)
+
+
+def _build_impl(
+ name: str,
+ cpp_files: Sequence[str] | str | None,
+ cuda_files: Sequence[str] | str | None,
+ extra_cflags: Sequence[str] | None,
+ extra_cuda_cflags: Sequence[str] | None,
+ extra_ldflags: Sequence[str] | None,
+ extra_include_paths: Sequence[str] | None,
+ build_directory: str | None,
+ need_lock: bool = True,
+) -> str:
+ """Real implementation of build function."""
+ # need to resolve the path to make it unique
+ cpp_path_list = [str(Path(p).resolve()) for p in _str_seq2list(cpp_files)]
+    cuda_path_list = [str(Path(p).resolve()) for p in _str_seq2list(cuda_files)]
+ with_cpp = bool(cpp_path_list)
+ with_cuda = bool(cuda_path_list)
+    assert with_cpp or with_cuda, "Either cpp_files or cuda_files must be provided."
+
+    extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else []
+    extra_cflags_list = list(extra_cflags) if extra_cflags is not None else []
+    extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is not None else []
+    extra_include_paths_list = list(extra_include_paths) if extra_include_paths is not None else []
+
+ build_dir: Path
+ if build_directory is None:
+        cache_dir = os.environ.get("TVM_FFI_CACHE_DIR", str(Path("~/.cache/tvm-ffi").expanduser()))
+ source_hash: str = _hash_sources(
+ None,
+ None,
+ cpp_path_list,
+ cuda_path_list,
+ {},
+ extra_cflags_list,
+ extra_cuda_cflags_list,
+ extra_ldflags_list,
+ extra_include_paths_list,
+ )
+ build_dir = Path(cache_dir).expanduser() / f"{name}_{source_hash}"
+ else:
+ build_dir = Path(build_directory).resolve()
+ build_dir.mkdir(parents=True, exist_ok=True)
+
+ # generate build.ninja
+ ninja_source = _generate_ninja_build(
+ name=name,
+ with_cuda=with_cuda,
+ extra_cflags=extra_cflags_list,
+ extra_cuda_cflags=extra_cuda_cflags_list,
+ extra_ldflags=extra_ldflags_list,
+ extra_include_paths=extra_include_paths_list,
+ cpp_files=cpp_path_list,
+ cuda_files=cuda_path_list,
+ )
+
+    # skip locking when the caller (e.g. build_inline) already holds the lock, to prevent deadlock
+ with FileLock(str(build_dir / "lock")) if need_lock else nullcontext():
+ # write build.ninja if it does not already exist
+ _maybe_write(str(build_dir / "build.ninja"), ninja_source)
+ # build the module
+ build_ninja(str(build_dir))
+ # Use appropriate extension based on platform
+ ext = ".dll" if IS_WINDOWS else ".so"
+ return str((build_dir / f"{name}{ext}").resolve())
+
+
+def build_inline(
name: str,
*,
cpp_sources: Sequence[str] | str | None = None,
@@ -484,22 +573,12 @@ def build_inline( # noqa: PLR0915, PLR0912
torch.testing.assert_close(x + 1, y)
"""
- if cpp_sources is None:
- cpp_source_list: list[str] = []
- elif isinstance(cpp_sources, str):
- cpp_source_list = [cpp_sources]
- else:
- cpp_source_list = list(cpp_sources)
+ cpp_source_list = _str_seq2list(cpp_sources)
cpp_source = "\n".join(cpp_source_list)
with_cpp = bool(cpp_source_list)
del cpp_source_list
- if cuda_sources is None:
- cuda_source_list: list[str] = []
- elif isinstance(cuda_sources, str):
- cuda_source_list = [cuda_sources]
- else:
- cuda_source_list = list(cuda_sources)
+ cuda_source_list = _str_seq2list(cuda_sources)
cuda_source = "\n".join(cuda_source_list)
with_cuda = bool(cuda_source_list)
del cuda_source_list
@@ -533,6 +612,8 @@ def build_inline( # noqa: PLR0915, PLR0912
source_hash: str = _hash_sources(
cpp_source,
cuda_source,
+ None,
+ None,
function_map,
extra_cflags_list,
extra_cuda_cflags_list,
@@ -544,27 +625,26 @@ def build_inline( # noqa: PLR0915, PLR0912
build_dir = Path(build_directory).resolve()
build_dir.mkdir(parents=True, exist_ok=True)
- # generate build.ninja
- ninja_source = _generate_ninja_build(
- name=name,
- build_dir=str(build_dir),
- with_cuda=with_cuda,
- extra_cflags=extra_cflags_list,
- extra_cuda_cflags=extra_cuda_cflags_list,
- extra_ldflags=extra_ldflags_list,
- extra_include_paths=extra_include_paths_list,
- )
+ cpp_file = str((build_dir / "main.cpp").resolve())
+ cuda_file = str((build_dir / "cuda.cu").resolve())
+
with FileLock(str(build_dir / "lock")):
- # write source files and build.ninja if they do not already exist
- _maybe_write(str(build_dir / "main.cpp"), cpp_source)
+ # write source files if they do not already exist
+ _maybe_write(cpp_file, cpp_source)
if with_cuda:
- _maybe_write(str(build_dir / "cuda.cu"), cuda_source)
- _maybe_write(str(build_dir / "build.ninja"), ninja_source)
- # build the module
- build_ninja(str(build_dir))
- # Use appropriate extension based on platform
- ext = ".dll" if IS_WINDOWS else ".so"
- return str((build_dir / f"{name}{ext}").resolve())
+ _maybe_write(cuda_file, cuda_source)
+
+ return _build_impl(
+ name=name,
+ cpp_files=[cpp_file] if with_cpp else [],
+ cuda_files=[cuda_file] if with_cuda else [],
+ extra_cflags=extra_cflags_list,
+ extra_cuda_cflags=extra_cuda_cflags_list,
+ extra_ldflags=extra_ldflags_list,
+ extra_include_paths=extra_include_paths_list,
+ build_directory=str(build_dir),
+        need_lock=False,  # we already hold the lock
+ )
def load_inline(
@@ -702,3 +782,280 @@ def load_inline(
build_directory=build_directory,
)
)
+
+
+def build(
+ name: str,
+ *,
+ cpp_files: Sequence[str] | str | None = None,
+ cuda_files: Sequence[str] | str | None = None,
+ extra_cflags: Sequence[str] | None = None,
+ extra_cuda_cflags: Sequence[str] | None = None,
+ extra_ldflags: Sequence[str] | None = None,
+ extra_include_paths: Sequence[str] | None = None,
+ build_directory: str | None = None,
+) -> str:
+ """Compile and build a C++/CUDA module from source files.
+
+    This function compiles the given C++ and/or CUDA source files into a shared library. Both ``cpp_files`` and
+    ``cuda_files`` are compiled to object files, and then linked together into a shared library. It's possible to only
+    provide cpp_files or cuda_files. The path to the compiled shared library is returned.
+
+    Note that this function does not automatically export functions to the tvm ffi module. You need to
+    manually use the TVM FFI export macros (e.g., ``TVM_FFI_DLL_EXPORT_TYPED_FUNC``) in your source files to export
+    functions. This gives you more control over which functions are exported and how they are exported.
+
+    Extra compiler and linker flags can be provided via the ``extra_cflags``, ``extra_cuda_cflags``, and ``extra_ldflags``
+    parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
+    flags for your specific use case.
+
+    The include directories of tvm ffi and dlpack are added by default so that the compiler can find their headers. Thus,
+    you can include any tvm ffi header in your source files. You can also provide additional include paths via the
+    ``extra_include_paths`` parameter and include custom headers in your source code.
+
+    The compiled shared library is cached in a cache directory to avoid recompilation. The ``build_directory`` parameter
+    specifies the build directory; if not given, a default tvm ffi cache directory is used. That default can be
+    overridden via the ``TVM_FFI_CACHE_DIR`` environment variable and falls back to ``~/.cache/tvm-ffi``.
+
+ Parameters
+ ----------
+ name
+ The name of the tvm ffi module.
+ cpp_files
+        The C++ source files to compile. It can be a list of file paths or a single file path. Both absolute and
+        relative paths are supported.
+    cuda_files
+        The CUDA source files to compile. It can be a list of file paths or a single file path. Both absolute and
+        relative paths are supported.
+ extra_cflags
+ The extra compiler flags for C++ compilation.
+ The default flags are:
+
+ - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
+ - On Windows: ['/std:c++17', '/MD', '/O2']
+
+ extra_cuda_cflags
+ The extra compiler flags for CUDA compilation.
+ The default flags are:
+
+ - ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] (Linux/macOS)
+ - ['-Xcompiler', '/std:c++17', '/O2'] (Windows)
+
+ extra_ldflags
+ The extra linker flags.
+ The default flags are:
+
+ - On Linux/macOS: ['-shared', '-L<tvm_ffi_lib_path>', '-ltvm_ffi']
+        - On Windows: ['/DLL', '/LIBPATH:<tvm_ffi_lib_path>', '<tvm_ffi_lib_name>.lib']
+
+ extra_include_paths
+        The extra include paths for header files. Both absolute and relative paths are supported.
+
+ build_directory
+        The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
+        cache directory is ``~/.cache/tvm-ffi``. You can also set the ``TVM_FFI_CACHE_DIR`` environment variable
+        to specify the cache directory.
+
+ Returns
+ -------
+ lib_path: str
+ The path to the built shared library.
+
+ Example
+ -------
+
+ .. code-block:: python
+
+ import torch
+ from tvm_ffi import Module
+ import tvm_ffi.cpp
+
+    # Assume we have a C++ source file "my_ops.cpp" with the following content:
+ # ```cpp
+ # #include <tvm/ffi/container/tensor.h>
+ # #include <tvm/ffi/dtype.h>
+ # #include <tvm/ffi/error.h>
+ # #include <tvm/ffi/extra/c_env_api.h>
+ # #include <tvm/ffi/function.h>
+ #
+ # void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+ # TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
+ # DLDataType f32_dtype{kDLFloat, 32, 1};
+    #   TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+    #   TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+    #   TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+    #   TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+    #   for (int i = 0; i < x.size(0); ++i) {
+    #     static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
+ # }
+ # }
+ #
+ # TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu);
+ # ```
+
+ # compile the cpp source file and get the library path
+ lib_path: str = tvm_ffi.cpp.build(
+ name='my_ops',
+ cpp_files='my_ops.cpp'
+ )
+
+ # load the module
+ mod: Module = tvm_ffi.load_module(lib_path)
+
+ # use the function from the loaded module
+ x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+ y = torch.empty_like(x)
+ mod.add_one_cpu(x, y)
+ torch.testing.assert_close(x + 1, y)
+
+ """
+ return _build_impl(
+ name=name,
+ cpp_files=cpp_files,
+ cuda_files=cuda_files,
+ extra_cflags=extra_cflags,
+ extra_cuda_cflags=extra_cuda_cflags,
+ extra_ldflags=extra_ldflags,
+ extra_include_paths=extra_include_paths,
+ build_directory=build_directory,
+ need_lock=True,
+ )
+
+
+def load(
+ name: str,
+ *,
+ cpp_files: Sequence[str] | str | None = None,
+ cuda_files: Sequence[str] | str | None = None,
+ extra_cflags: Sequence[str] | None = None,
+ extra_cuda_cflags: Sequence[str] | None = None,
+ extra_ldflags: Sequence[str] | None = None,
+ extra_include_paths: Sequence[str] | None = None,
+ build_directory: str | None = None,
+) -> Module:
+ """Compile, build and load a C++/CUDA module from source files.
+
+    This function compiles the given C++ and/or CUDA source files into a shared library and loads it as a tvm ffi
+    module. Both ``cpp_files`` and ``cuda_files`` are compiled to object files, and then linked together into a shared
+    library. It's possible to only provide cpp_files or cuda_files.
+
+    Note that this function does not automatically export functions to the tvm ffi module. You need to
+    manually use the TVM FFI export macros (e.g., ``TVM_FFI_DLL_EXPORT_TYPED_FUNC``) in your source files to export
+    functions. This gives you more control over which functions are exported and how they are exported.
+
+    Extra compiler and linker flags can be provided via the ``extra_cflags``, ``extra_cuda_cflags``, and ``extra_ldflags``
+    parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
+    flags for your specific use case.
+
+    The include directories of tvm ffi and dlpack are added by default so that the compiler can find their headers. Thus,
+    you can include any tvm ffi header in your source files. You can also provide additional include paths via the
+    ``extra_include_paths`` parameter and include custom headers in your source code.
+
+    The compiled shared library is cached in a cache directory to avoid recompilation. The ``build_directory`` parameter
+    specifies the build directory; if not given, a default tvm ffi cache directory is used. That default can be
+    overridden via the ``TVM_FFI_CACHE_DIR`` environment variable and falls back to ``~/.cache/tvm-ffi``.
+
+ Parameters
+ ----------
+ name: str
+ The name of the tvm ffi module.
+ cpp_files: Sequence[str] | str, optional
+        The C++ source files to compile. It can be a list of file paths or a single file path. Both absolute and
+        relative paths are supported.
+    cuda_files: Sequence[str] | str, optional
+        The CUDA source files to compile. It can be a list of file paths or a single file path. Both absolute and
+        relative paths are supported.
+ extra_cflags: Sequence[str], optional
+ The extra compiler flags for C++ compilation.
+ The default flags are:
+
+ - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
+ - On Windows: ['/std:c++17', '/MD', '/O2']
+
+ extra_cuda_cflags: Sequence[str], optional
+ The extra compiler flags for CUDA compilation.
+ The default flags are:
+
+ - ['-Xcompiler', '-fPIC', '-std=c++17', '-O2'] (Linux/macOS)
+ - ['-Xcompiler', '/std:c++17', '/O2'] (Windows)
+
+ extra_ldflags: Sequence[str], optional
+ The extra linker flags.
+ The default flags are:
+
+ - On Linux/macOS: ['-shared', '-L<tvm_ffi_lib_path>', '-ltvm_ffi']
+        - On Windows: ['/DLL', '/LIBPATH:<tvm_ffi_lib_path>', '<tvm_ffi_lib_name>.lib']
+
+ extra_include_paths: Sequence[str], optional
+        The extra include paths for header files. Both absolute and relative paths are supported.
+
+ build_directory: str, optional
+        The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
+        cache directory is ``~/.cache/tvm-ffi``. You can also set the ``TVM_FFI_CACHE_DIR`` environment variable
+        to specify the cache directory.
+
+ Returns
+ -------
+ mod: Module
+ The loaded tvm ffi module.
+
+
+ Example
+ -------
+
+ .. code-block:: python
+
+ import torch
+ from tvm_ffi import Module
+ import tvm_ffi.cpp
+
+    # Assume we have a C++ source file "my_ops.cpp" with the following content:
+ # ```cpp
+ # #include <tvm/ffi/container/tensor.h>
+ # #include <tvm/ffi/dtype.h>
+ # #include <tvm/ffi/error.h>
+ # #include <tvm/ffi/extra/c_env_api.h>
+ # #include <tvm/ffi/function.h>
+ #
+ # void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+ # TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
+ # DLDataType f32_dtype{kDLFloat, 32, 1};
+    #   TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+    #   TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+    #   TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+    #   TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+    #   for (int i = 0; i < x.size(0); ++i) {
+    #     static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
+ # }
+ # }
+ #
+ # TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu);
+ # ```
+
+ # compile the cpp source file and load the module
+ mod: Module = tvm_ffi.cpp.load(
+ name='my_ops',
+ cpp_files='my_ops.cpp'
+ )
+
+ # use the function from the loaded module
+ x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+ y = torch.empty_like(x)
+ mod.add_one_cpu(x, y)
+ torch.testing.assert_close(x + 1, y)
+
+ """
+ return load_module(
+ build(
+ name=name,
+ cpp_files=cpp_files,
+ cuda_files=cuda_files,
+ extra_cflags=extra_cflags,
+ extra_cuda_cflags=extra_cuda_cflags,
+ extra_ldflags=extra_ldflags,
+ extra_include_paths=extra_include_paths,
+ build_directory=build_directory,
+ )
+ )
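As the docstrings above describe, builds are cached under a per-module
directory derived from a hash of the inputs. A sketch of the two ways to
control where build artifacts go (paths here are illustrative):

    import os
    import tvm_ffi.cpp

    # option 1: override the default cache root (~/.cache/tvm-ffi);
    # must be set before build() is called, since it is read at build time
    os.environ["TVM_FFI_CACHE_DIR"] = "/tmp/tvm-ffi-cache"

    # option 2: pin a single build to an explicit directory,
    # bypassing the hash-derived cache subdirectory
    lib_path = tvm_ffi.cpp.build(
        name="my_ops",
        cpp_files="my_ops.cpp",
        build_directory="/tmp/my_ops_build",
    )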
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 288b3ed..4be433f 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -668,7 +668,7 @@ def main() -> None: # noqa: PLR0912, PLR0915
"""Build the torch c dlpack extension."""
# we need to set the following env to avoid tvm_ffi to build the torch c-dlpack addon during importing
os.environ["TVM_FFI_DISABLE_TORCH_C_DLPACK"] = "1"
- from tvm_ffi.cpp.load_inline import build_ninja # noqa: PLC0415
+ from tvm_ffi.cpp.extension import build_ninja # noqa: PLC0415
from tvm_ffi.utils.lockfile import FileLock # noqa: PLC0415
parser = argparse.ArgumentParser(
diff --git a/tests/python/test_build.cc b/tests/python/test_build.cc
new file mode 100644
index 0000000..e25e8cf
--- /dev/null
+++ b/tests/python/test_build.cc
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "test_build.h"
+
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/dtype.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/function.h>
+
+void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+ TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
+ DLDataType f32_dtype{kDLFloat, 32, 1};
+ TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+ TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+ TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+ TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+ for (int i = 0; i < x.size(0); ++i) {
+    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
+ }
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cpu, add_one_cpu);
diff --git a/tests/python/test_build.h b/tests/python/test_build.h
new file mode 100644
index 0000000..ff515a6
--- /dev/null
+++ b/tests/python/test_build.h
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_FFI_TEST_BUILD_H_
+#define TVM_FFI_TEST_BUILD_H_
+#include <tvm/ffi/container/tensor.h>
+
+void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y);
+
+#endif // TVM_FFI_TEST_BUILD_H_
diff --git a/python/tvm_ffi/cpp/__init__.py b/tests/python/test_build.py
similarity index 57%
copy from python/tvm_ffi/cpp/__init__.py
copy to tests/python/test_build.py
index 8835e4a..30153dc 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/tests/python/test_build.py
@@ -14,6 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""C++ integration helpers for building and loading inline modules."""
+import pathlib
-from .load_inline import build_inline, load_inline
+import numpy
+import pytest
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def test_build_cpp() -> None:
+ cpp_path = pathlib.Path(__file__).parent.resolve() / "test_build.cc"
+ output_lib_path = tvm_ffi.cpp.build(
+ name="hello",
+ cpp_files=[str(cpp_path)],
+ )
+
+ mod: Module = tvm_ffi.load_module(output_lib_path)
+
+ x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
+ y = numpy.empty_like(x)
+ mod.add_one_cpu(x, y)
+ numpy.testing.assert_equal(x + 1, y)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
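For completeness, a companion test for the one-step `load` path would
mirror `test_build_cpp` (a sketch, reusing the same test_build.cc source):

    def test_load_cpp() -> None:
        cpp_path = pathlib.Path(__file__).parent.resolve() / "test_build.cc"
        # build and load in a single call instead of build() + load_module()
        mod: Module = tvm_ffi.cpp.load(name="hello", cpp_files=[str(cpp_path)])

        x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
        y = numpy.empty_like(x)
        mod.add_one_cpu(x, y)
        numpy.testing.assert_equal(x + 1, y)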