This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/dev by this push:
new 742b16e add documentation for tvm_ffi.cpp.load_inline (#12)
742b16e is described below
commit 742b16e5c71cf388e3bf8834763755448f3f0033
Author: Yaoyao Ding <[email protected]>
AuthorDate: Sun Sep 14 15:22:36 2025 -0400
add documentation for tvm_ffi.cpp.load_inline (#12)
Signed-off-by: Yaoyao Ding <[email protected]>
---
docs/guides/python_guide.md | 41 ++++++++++++++++++
docs/reference/python/cpp/index.rst | 10 +++++
docs/reference/python/index.rst | 9 ++++
examples/inline_module/main.py | 6 +--
python/tvm_ffi/cpp/load_inline.py | 84 +++++++++++++++++++++++++++----------
tests/python/test_load_inline.py | 18 ++++----
6 files changed, 135 insertions(+), 33 deletions(-)
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index 0ab56eb..fdf03a5 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -139,6 +139,47 @@ assert map_obj["b"] == 2
When container values are returned from FFI functions, they are also stored in these types respectively.
+## Inline Module
+
+You can also load an _inline module_ where the C++/CUDA code is embedded directly in the Python script and compiled
+on the fly. For example, we can define a simple kernel that adds one to each element of an array as follows:
+
+```python
+import torch
+from tvm_ffi import Module
+import tvm_ffi.cpp
+
+# define the cpp source code
+cpp_source = '''
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ // implementation of a library function
+ TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+ DLDataType f32_dtype{kDLFloat, 32, 1};
+ TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+ TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+ TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+ for (int i = 0; i < x->shape[0]; ++i) {
+ static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+ }
+ }
+'''
+
+# compile the cpp source code and load the module
+mod: Module = tvm_ffi.cpp.load_inline(
+ name='hello', cpp_sources=cpp_source, functions='add_one_cpu'
+)
+
+# use the function from the loaded module to perform the computation
+x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+y = torch.empty_like(x)
+mod.add_one_cpu(x, y)
+torch.testing.assert_close(x + 1, y)
+```
+
+The above code defines a C++ function `add_one_cpu` in a Python script, compiles it on the fly, and loads the result
+as a {py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You can then call the function
+`add_one_cpu` from the module as usual.
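+
+Besides a single function name, the `functions` argument of {py:func}`tvm_ffi.cpp.load_inline` also accepts a list of
+names or a mapping from exported function names to docstrings. Below is a minimal sketch that reuses `cpp_source`, `x`,
+and `y` from the example above (the module name and docstring text here are only illustrative):
+
+```python
+# export add_one_cpu together with a docstring (illustrative sketch)
+mod2 = tvm_ffi.cpp.load_inline(
+    name='hello_documented',
+    cpp_sources=cpp_source,
+    functions={'add_one_cpu': 'Add one to every element of x and store the result in y.'},
+)
+mod2.add_one_cpu(x, y)
+torch.testing.assert_close(x + 1, y)
+```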
## Error Handling
diff --git a/docs/reference/python/cpp/index.rst b/docs/reference/python/cpp/index.rst
new file mode 100644
index 0000000..8737633
--- /dev/null
+++ b/docs/reference/python/cpp/index.rst
@@ -0,0 +1,10 @@
+tvm_ffi.cpp
+-----------
+
+.. automodule:: tvm_ffi.cpp
+ :no-members:
+
+.. autosummary::
+ :toctree: generated/
+
+ load_inline
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index b357420..482c19d 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -66,3 +66,12 @@ Containers
Array
Map
+
+
+Utility
+-------
+
+.. toctree::
+ :maxdepth: 1
+
+ cpp/index.rst
diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py
index da94fb5..98b939e 100644
--- a/examples/inline_module/main.py
+++ b/examples/inline_module/main.py
@@ -25,7 +25,7 @@ def main():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -38,7 +38,7 @@ def main():
}
}
- void add_one_cuda(DLTensor* x, DLTensor* y);
+ void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y);
""",
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
@@ -48,7 +48,7 @@ def main():
}
}
- void add_one_cuda(DLTensor* x, DLTensor* y) {
+ void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index c836205..3c3c8d3 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -263,6 +263,7 @@ def _build_ninja(build_dir: str) -> None:
def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
"""Decorate the given source code with TVM FFI export macros."""
sources = [
+ "#include <tvm/ffi/container/tensor.h>",
"#include <tvm/ffi/dtype.h>",
"#include <tvm/ffi/error.h>",
"#include <tvm/ffi/extra/c_env_api.h>",
@@ -292,33 +293,34 @@ def load_inline(
extra_include_paths: Sequence[str] | None = None,
build_directory: Optional[str] = None,
) -> Module:
- """Compile and load a C++/CUDA tvm ffi module from inline source code.
+ """Compile and load a C++/CUDA module from inline source code.
-    This function compiles the given C++ and/or CUDA source code into a shared library. Both cpp_sources and
-    cuda_sources are compiled to an object file, and then linked together into a shared library. It's possible to only
+    This function compiles the given C++ and/or CUDA source code into a shared library. Both ``cpp_sources`` and
+    ``cuda_sources`` are compiled to an object file, and then linked together into a shared library. It's possible to only
provide cpp_sources or cuda_sources.
-    The `functions` parameter is used to specify which functions in the source code should be exported to the tvm ffi module.
-    It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the exported
-    functions, and the values are docstrings for the functions. When a sequence or a single string is given, they are the
-    functions needed to be exported, and the docstrings are set to empty strings. A single function name can also be given
-    as a string, indicating that only one function is to be exported.
+    The ``functions`` parameter is used to specify which functions in the source code should be exported to the tvm ffi
+    module. It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the
+    exported functions, and the values are docstrings for the functions. When a sequence of strings is given, they are
+    the names of the functions to be exported, and the docstrings are set to empty strings. A single function name can
+    also be given as a string, indicating that only one function is to be exported.
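+
+    For example, the mapping form attaches docstrings to the exported functions. A minimal sketch (here ``cpp_source``
+    is assumed to be a string of C++ code that defines ``add_one_cpu`` and ``add_two_cpu``; see the full example at the
+    end of this docstring)::
+
+        mod = tvm_ffi.cpp.load_inline(
+            name='demo',
+            cpp_sources=cpp_source,
+            functions={'add_one_cpu': 'add one to each element', 'add_two_cpu': 'add two to each element'},
+        )
+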
-    Extra compiler and linker flags can be provided via the `extra_cflags`, `extra_cuda_cflags`, and `extra_ldflags`
+    Extra compiler and linker flags can be provided via the ``extra_cflags``, ``extra_cuda_cflags``, and ``extra_ldflags``
    parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
flags for your specific use case.
-    The include dir of tvm ffi and dlpack are used by default for linker to find the headers. Thus, you can include
-    any header from tvm ffi and dlpack in your source code. You can also provide additional include paths via the
-    `extra_include_paths` parameter and include custom headers in your source code.
+    The include directories of tvm ffi and dlpack are used by default for the compiler to find the headers. Thus, you can
+    include any header from tvm ffi in your source code. You can also provide additional include paths via the
+    ``extra_include_paths`` parameter and include custom headers in your source code.
    The compiled shared library is cached in a cache directory to avoid recompilation. The `build_directory` parameter
    is provided to specify the build directory. If not specified, a default tvm ffi cache directory will be used.
    The default cache directory can be specified via the `TVM_FFI_CACHE_DIR` environment variable. If not specified,
- the default cache directory is `~/.cache/tvm-ffi`.
+ the default cache directory is ``~/.cache/tvm-ffi``.
Parameters
----------
+
name: str
The name of the tvm ffi module.
cpp_sources: Sequence[str] | str, optional
@@ -335,31 +337,71 @@ def load_inline(
extra_cflags: Sequence[str], optional
The extra compiler flags for C++ compilation.
The default flags are:
+
- On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
- - On Windows: ['/std:c++17']
- extra_cuda_cflags:
+ - On Windows: ['/std:c++17', '/O2']
+
+ extra_cuda_cflags: Sequence[str], optional
The extra compiler flags for CUDA compilation.
- The default flags are:
- - On Linux/macOS: ['-Xcompiler', '-fPIC', '-std=c++17', '-O2']
- - On Windows: ['-Xcompiler', '/std:c++17', '/O2']
+
extra_ldflags: Sequence[str], optional
The extra linker flags.
The default flags are:
+
- On Linux/macOS: ['-shared']
- On Windows: ['/DLL']
+
extra_include_paths: Sequence[str], optional
The extra include paths.
- The default include paths are:
- - The include path of tvm ffi
+
build_directory: str, optional
        The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
-        cache directory is `~/.cache/tvm-ffi`. You can also set the `TVM_FFI_CACHE_DIR` environment variable to
+        cache directory is ``~/.cache/tvm-ffi``. You can also set the ``TVM_FFI_CACHE_DIR`` environment variable to
specify the cache directory.
Returns
-------
mod: Module
The loaded tvm ffi module.
+
+
+ Example
+ -------
+
+ .. code-block:: python
+
+ import torch
+ from tvm_ffi import Module
+ import tvm_ffi.cpp
+
+ # define the cpp source code
+ cpp_source = '''
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+ // implementation of a library function
+ TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+ DLDataType f32_dtype{kDLFloat, 32, 1};
+                TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+ TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+                TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+                TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+ for (int i = 0; i < x->shape[0]; ++i) {
+                    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+ }
+ }
+ '''
+
+ # compile the cpp source code and load the module
+ mod: Module = tvm_ffi.cpp.load_inline(
+ name='hello',
+ cpp_sources=cpp_source,
+ functions='add_one_cpu'
+ )
+
+        # use the function from the loaded module to perform the computation
+ x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+ y = torch.empty_like(x)
+ mod.add_one_cpu(x, y)
+ torch.testing.assert_close(x + 1, y)
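+
+        # extra compiler flags can also be passed via extra_cflags
+        # (a sketch; '-O3' is a flag for Linux/macOS toolchains)
+        mod_opt: Module = tvm_ffi.cpp.load_inline(
+            name='hello_opt',
+            cpp_sources=cpp_source,
+            functions='add_one_cpu',
+            extra_cflags=['-O3'],
+        )
+        mod_opt.add_one_cpu(x, y)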
"""
if cpp_sources is None:
cpp_sources = []
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 7ecc100..d72bfa7 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -34,7 +34,7 @@ def test_load_inline_cpp():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -61,7 +61,7 @@ def test_load_inline_cpp_with_docstrings():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -89,7 +89,7 @@ def test_load_inline_cpp_multiple_sources():
name="hello",
cpp_sources=[
r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -103,7 +103,7 @@ def test_load_inline_cpp_multiple_sources():
}
""",
r"""
- void add_two_cpu(DLTensor* x, DLTensor* y) {
+ void add_two_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -131,7 +131,7 @@ def test_load_inline_cpp_build_dir():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -168,7 +168,7 @@ def test_load_inline_cuda():
}
}
- void add_one_cuda(DLTensor* x, DLTensor* y) {
+ void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -249,7 +249,7 @@ def test_load_inline_both():
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
- void add_one_cpu(DLTensor* x, DLTensor* y) {
+ void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -262,7 +262,7 @@ def test_load_inline_both():
}
}
- void add_one_cuda(DLTensor* x, DLTensor* y);
+ void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y);
""",
cuda_sources=r"""
__global__ void AddOneKernel(float* x, float* y, int n) {
@@ -272,7 +272,7 @@ def test_load_inline_both():
}
}
- void add_one_cuda(DLTensor* x, DLTensor* y) {
+ void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};