This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fcf94f  [Utils] Add `build_inline` utility (#73)
4fcf94f is described below

commit 4fcf94f6e2dcce9901e0e42c30c7d4d57487619d
Author: Yaoyao Ding <[email protected]>
AuthorDate: Mon Sep 29 18:52:53 2025 -0400

    [Utils] Add `build_inline` utility (#73)
    
    This PR adds the `tvm_ffi.cpp.build_inline` utility function.
    
    Example:
    ```python
    import torch
    from tvm_ffi import Module
    import tvm_ffi.cpp
    
    # define the cpp source code
    cpp_source = '''
         void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
           // implementation of a library function
           TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
           DLDataType f32_dtype{kDLFloat, 32, 1};
           TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
           TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
           TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
           TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
           for (int i = 0; i < x->shape[0]; ++i) {
             static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
           }
         }
    '''
    
    # compile the cpp source code into a shared library
    lib_path: str = tvm_ffi.cpp.build_inline(
        name='hello',
        cpp_sources=cpp_source,
        functions='add_one_cpu',
        # build_directory='./add_one/',  # optionally specify the build directory
    )
    
    # load the module
    mod: Module = tvm_ffi.load_module(lib_path)
    
    # use the function from the loaded module to perform the computation
    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
    y = torch.empty_like(x)
    mod.add_one_cpu(x, y)
    torch.testing.assert_close(x + 1, y)
    ```
    
    The `build_inline` function is similar to `tvm_ffi.cpp.load_inline` but
    only builds the module, without loading it afterwards. It returns the
    path to the built shared library (e.g.,
    `~/.cache/tvm-ffi/hello_95b50659cc3e9b6d/hello.so`).
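    
    With this change, `load_inline` itself becomes a thin wrapper that builds
    and then loads. Conceptually (a minimal sketch; the real function takes
    the keyword-only arguments listed in the diff below, elided here as
    `**kwargs`):
    ```python
    # sketch only: build the shared library, then load it as a Module
    def load_inline(name, **kwargs):
        lib_path = tvm_ffi.cpp.build_inline(name=name, **kwargs)
        return tvm_ffi.load_module(lib_path)
    ```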
    
    Minor:
    Also add `autodoc_typehints = 'description'` (which makes Sphinx render
    type hints in the parameter descriptions rather than in the signature),
    following
    https://github.com/apache/tvm-ffi/pull/52#discussion_r2374528315.
---
 docs/conf.py                      |   1 +
 docs/guides/python_guide.md       |   3 +
 docs/reference/python/index.rst   |   1 +
 python/tvm_ffi/cpp/__init__.py    |   2 +-
 python/tvm_ffi/cpp/load_inline.py | 157 +++++++++++++++++++++++++++++++++++---
 tests/python/test_build_inline.py |  52 +++++++++++++
 6 files changed, 206 insertions(+), 10 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 50495de..0d4c3f1 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -177,6 +177,7 @@ autodoc_default_options = {
     "inherited-members": False,
     "member-order": "bysource",
 }
+autodoc_typehints = "description"
 
 # -- Other Options --------------------------------------------------------
 
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index 0fd35d8..de2adfa 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -181,6 +181,9 @@ The above code defines a C++ function `add_one_cpu` in Python script, compiles i
 {py:class}`tvm_ffi.Module` object via {py:func}`tvm_ffi.cpp.load_inline`. You can then call the function `add_one_cpu`
 from the module as usual.
 
+We can also use {py:func}`tvm_ffi.cpp.build_inline` to build the inline module without loading it. The path to the built shared library is returned
+and can be loaded via {py:func}`tvm_ffi.load_module`.
+
 ## Error Handling
 
 An FFI function may raise an error. In such cases, the Python package will automatically
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 8b478e8..7cae49d 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -87,6 +87,7 @@ C++ integration helpers for building and loading inline modules.
   :toctree: cpp/generated/
 
   cpp.load_inline
+  cpp.build_inline
 
 
 .. (Experimental) Dataclasses
diff --git a/python/tvm_ffi/cpp/__init__.py b/python/tvm_ffi/cpp/__init__.py
index ede2b54..8835e4a 100644
--- a/python/tvm_ffi/cpp/__init__.py
+++ b/python/tvm_ffi/cpp/__init__.py
@@ -16,4 +16,4 @@
 # under the License.
 """C++ integration helpers for building and loading inline modules."""
 
-from .load_inline import load_inline
+from .load_inline import build_inline, load_inline
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index d7a5c14..3d1d5b5 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -360,7 +360,7 @@ def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
     return "\n".join(sources)
 
 
-def load_inline(  # noqa: PLR0912, PLR0915
+def build_inline(  # noqa: PLR0915, PLR0912
     name: str,
     *,
     cpp_sources: Sequence[str] | str | None = None,
@@ -371,12 +371,12 @@ def load_inline(  # noqa: PLR0912, PLR0915
     extra_ldflags: Sequence[str] | None = None,
     extra_include_paths: Sequence[str] | None = None,
     build_directory: str | None = None,
-) -> Module:
-    """Compile and load a C++/CUDA module from inline source code.
+) -> str:
+    """Compile and build a C++/CUDA module from inline source code.
 
     This function compiles the given C++ and/or CUDA source code into a shared library. Both ``cpp_sources`` and
     ``cuda_sources`` are compiled to an object file, and then linked together into a shared library. It's possible to only
-    provide cpp_sources or cuda_sources.
+    provide cpp_sources or cuda_sources. The path to the compiled shared library is returned.
 
     The ``functions`` parameter is used to specify which functions in the source code should be exported to the tvm ffi
     module. It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the
@@ -439,9 +439,8 @@ def load_inline(  # noqa: PLR0912, PLR0915
 
     Returns
     -------
-    mod: Module
-        The loaded tvm ffi module.
-
+    lib_path: str
+        The path to the built shared library.
 
     Example
     -------
@@ -469,12 +468,15 @@ def load_inline(  # noqa: PLR0912, PLR0915
         '''
 
         # compile the cpp source code and load the module
-        mod: Module = tvm_ffi.cpp.load_inline(
+        lib_path: str = tvm_ffi.cpp.build_inline(
             name='hello',
             cpp_sources=cpp_source,
             functions='add_one_cpu'
         )
 
+        # load the module
+        mod: Module = tvm_ffi.load_module(lib_path)
+
         # use the function from the loaded module to perform the computation
         x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
         y = torch.empty_like(x)
@@ -562,4 +564,141 @@ def load_inline(  # noqa: PLR0912, PLR0915
         _build_ninja(str(build_dir))
         # Use appropriate extension based on platform
         ext = ".dll" if IS_WINDOWS else ".so"
-        return load_module(str((build_dir / f"{name}{ext}").resolve()))
+        return str((build_dir / f"{name}{ext}").resolve())
+
+
+def load_inline(
+    name: str,
+    *,
+    cpp_sources: Sequence[str] | str | None = None,
+    cuda_sources: Sequence[str] | str | None = None,
+    functions: Mapping[str, str] | Sequence[str] | str | None = None,
+    extra_cflags: Sequence[str] | None = None,
+    extra_cuda_cflags: Sequence[str] | None = None,
+    extra_ldflags: Sequence[str] | None = None,
+    extra_include_paths: Sequence[str] | None = None,
+    build_directory: str | None = None,
+) -> Module:
+    """Compile, build and load a C++/CUDA module from inline source code.
+
+    This function compiles the given C++ and/or CUDA source code into a shared library. Both ``cpp_sources`` and
+    ``cuda_sources`` are compiled to an object file, and then linked together into a shared library. It's possible to only
+    provide cpp_sources or cuda_sources.
+
+    The ``functions`` parameter is used to specify which functions in the source code should be exported to the tvm ffi
+    module. It can be a mapping, a sequence, or a single string. When a mapping is given, the keys are the names of the
+    exported functions, and the values are docstrings for the functions. When a sequence of strings is given, they are
+    the names of the functions to be exported, and the docstrings are set to empty strings. A single function name can
+    also be given as a string, indicating that only one function is to be exported.
+
+    Extra compiler and linker flags can be provided via the ``extra_cflags``, ``extra_cuda_cflags``, and ``extra_ldflags``
+    parameters. The default flags are generally sufficient for most use cases, but you may need to provide additional
+    flags for your specific use case.
+
+    The include directories of tvm ffi and dlpack are used by default for the compiler to find the headers. Thus, you can
+    include any header from tvm ffi in your source code. You can also provide additional include paths via the
+    ``extra_include_paths`` parameter and include custom headers in your source code.
+
+    The compiled shared library is cached in a cache directory to avoid recompilation. The `build_directory` parameter
+    is provided to specify the build directory. If not specified, a default tvm ffi cache directory will be used. The
+    cache directory can be overridden via the `TVM_FFI_CACHE_DIR` environment variable and defaults to
+    ``~/.cache/tvm-ffi``.
+
+    Parameters
+    ----------
+    name: str
+        The name of the tvm ffi module.
+    cpp_sources: Sequence[str] | str, optional
+        The C++ source code. It can be a list of sources or a single source.
+    cuda_sources: Sequence[str] | str, optional
+        The CUDA source code. It can be a list of sources or a single source.
+    functions: Mapping[str, str] | Sequence[str] | str, optional
+        The functions in cpp_sources or cuda_sources that will be exported to the tvm ffi module. When a mapping is
+        given, the keys are the names of the exported functions, and the values are docstrings for the functions. When
+        a sequence or a single string is given, they are the names of the functions to be exported, and the docstrings
+        are set to empty strings. A single function name can also be given as a string. When cpp_sources is given, the
+        functions must be declared (not necessarily defined) in the cpp_sources. When cpp_sources is not given, the
+        functions must be defined in the cuda_sources. If not specified, no function will be exported.
+    extra_cflags: Sequence[str], optional
+        The extra compiler flags for C++ compilation.
+        The default flags are:
+
+        - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
+        - On Windows: ['/std:c++17', '/O2']
+
+    extra_cuda_cflags: Sequence[str], optional
+        The extra compiler flags for CUDA compilation.
+
+    extra_ldflags: Sequence[str], optional
+        The extra linker flags.
+        The default flags are:
+
+        - On Linux/macOS: ['-shared']
+        - On Windows: ['/DLL']
+
+    extra_include_paths: Sequence[str], optional
+        The extra include paths.
+
+    build_directory: str, optional
+        The build directory. If not specified, a default tvm ffi cache directory will be used. By default, the
+        cache directory is ``~/.cache/tvm-ffi``. You can also set the ``TVM_FFI_CACHE_DIR`` environment variable to
+        specify the cache directory.
+
+    Returns
+    -------
+    mod: Module
+        The loaded tvm ffi module.
+
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        import torch
+        from tvm_ffi import Module
+        import tvm_ffi.cpp
+
+        # define the cpp source code
+        cpp_source = '''
+             void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+               // implementation of a library function
+               TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+               DLDataType f32_dtype{kDLFloat, 32, 1};
+               TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+               TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+               TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+               TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+               for (int i = 0; i < x->shape[0]; ++i) {
+                 static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+               }
+             }
+        '''
+
+        # compile the cpp source code and load the module
+        mod: Module = tvm_ffi.cpp.load_inline(
+            name='hello',
+            cpp_sources=cpp_source,
+            functions='add_one_cpu'
+        )
+
+        # use the function from the loaded module to perform the computation
+        x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
+        y = torch.empty_like(x)
+        mod.add_one_cpu(x, y)
+        torch.testing.assert_close(x + 1, y)
+
+    """
+    return load_module(
+        build_inline(
+            name=name,
+            cpp_sources=cpp_sources,
+            cuda_sources=cuda_sources,
+            functions=functions,
+            extra_cflags=extra_cflags,
+            extra_cuda_cflags=extra_cuda_cflags,
+            extra_ldflags=extra_ldflags,
+            extra_include_paths=extra_include_paths,
+            build_directory=build_directory,
+        )
+    )
diff --git a/tests/python/test_build_inline.py b/tests/python/test_build_inline.py
new file mode 100644
index 0000000..d436271
--- /dev/null
+++ b/tests/python/test_build_inline.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy
+import pytest
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
+def test_build_inline_cpp() -> None:
+    output_lib_path = tvm_ffi.cpp.build_inline(
+        name="hello",
+        cpp_sources=r"""
+            void add_one_cpu(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+              // implementation of a library function
+              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              DLDataType f32_dtype{kDLFloat, 32, 1};
+              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+              for (int i = 0; i < x->shape[0]; ++i) {
+                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              }
+            }
+        """,
+        functions=["add_one_cpu"],
+    )
+
+    mod: Module = tvm_ffi.load_module(output_lib_path)
+
+    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
+    y = numpy.empty_like(x)
+    mod.add_one_cpu(x, y)
+    numpy.testing.assert_equal(x + 1, y)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
