This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 86b391a4b6 [FFI] Support inline module (#18271)
86b391a4b6 is described below
commit 86b391a4b6507f681d68a3187f0ae4da65986ffb
Author: Yaoyao Ding <[email protected]>
AuthorDate: Fri Sep 5 17:09:58 2025 -0400
[FFI] Support inline module (#18271)
This PR adds initial support for load_inline in tvm_ffi
---
ffi/examples/inline_module/main.py | 86 ++++++++
ffi/python/tvm_ffi/cpp/__init__.py | 18 ++
ffi/python/tvm_ffi/cpp/load_inline.py | 382 ++++++++++++++++++++++++++++++++++
ffi/python/tvm_ffi/utils/__init__.py | 18 ++
ffi/python/tvm_ffi/utils/lockfile.py | 113 ++++++++++
ffi/tests/python/test_load_inline.py | 161 ++++++++++++++
6 files changed, 778 insertions(+)
diff --git a/ffi/examples/inline_module/main.py b/ffi/examples/inline_module/main.py
new file mode 100644
index 0000000000..574d55c678
--- /dev/null
+++ b/ffi/examples/inline_module/main.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import torch
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
def main():
    """Build a mixed C++/CUDA inline module and check both exported functions.

    ``add_one_cpu`` is exercised on a host tensor and ``add_one_cuda`` on a
    device copy of it; both must produce ``input + 1``.
    """
    cpp_code = r"""
        void AddOne(DLTensor* x, DLTensor* y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
          for (int i = 0; i < x->shape[0]; ++i) {
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
        }
    """
    cuda_code = r"""
        __global__ void AddOneKernel(float* x, float* y, int n) {
          int idx = blockIdx.x * blockDim.x + threadIdx.x;
          if (idx < n) {
            y[idx] = x[idx] + 1;
          }
        }

        void AddOneCUDA(DLTensor* x, DLTensor* y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";

          int64_t n = x->shape[0];
          int64_t nthread_per_block = 256;
          int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
          // Obtain the current stream from the environment
          // it will be set to torch.cuda.current_stream() when calling the function
          // with torch.Tensors
          cudaStream_t stream = static_cast<cudaStream_t>(
              TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
          // launch the kernel
          AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                                 static_cast<float*>(y->data), n);
        }
    """
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cpp_source=cpp_code,
        cuda_source=cuda_code,
        cpp_functions={"add_one_cpu": "AddOne"},
        cuda_functions={"add_one_cuda": "AddOneCUDA"},
    )

    # host path
    host_in = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
    host_out = torch.empty_like(host_in)
    mod.add_one_cpu(host_in, host_out)
    torch.testing.assert_close(host_in + 1, host_out)

    # device path: same data, copied to the current CUDA device
    dev_in = host_in.cuda()
    dev_out = torch.empty_like(dev_in)
    mod.add_one_cuda(dev_in, dev_out)
    torch.testing.assert_close(dev_in + 1, dev_out)


if __name__ == "__main__":
    main()
diff --git a/ffi/python/tvm_ffi/cpp/__init__.py b/ffi/python/tvm_ffi/cpp/__init__.py
new file mode 100644
index 0000000000..632698f443
--- /dev/null
+++ b/ffi/python/tvm_ffi/cpp/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .load_inline import load_inline
diff --git a/ffi/python/tvm_ffi/cpp/load_inline.py b/ffi/python/tvm_ffi/cpp/load_inline.py
new file mode 100644
index 0000000000..a9ec1c3997
--- /dev/null
+++ b/ffi/python/tvm_ffi/cpp/load_inline.py
@@ -0,0 +1,382 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Sequence, Optional, Mapping
+import os
+import sys
+import glob
+import hashlib
+import shutil
+import subprocess
+import functools
+
+from tvm_ffi.module import Module, load_module
+from tvm_ffi.utils import FileLock
+from tvm_ffi.libinfo import find_include_path, find_dlpack_include_path
+
# True when running on Windows; selects the MSVC-style defaults (cl, /std:c++17, /DLL) below.
IS_WINDOWS: bool = sys.platform == "win32"
+
+
+def _hash_sources(
+ cpp_source: str,
+ cuda_source: str,
+ cpp_functions: Mapping[str, str],
+ cuda_functions: Mapping[str, str],
+ extra_cflags: Sequence[str],
+ extra_cuda_cflags: Sequence[str],
+ extra_ldflags: Sequence[str],
+ extra_include_paths: Sequence[str],
+) -> str:
+ """Generate a unique hash for the given sources and functions."""
+ m = hashlib.sha256()
+ m.update(cpp_source.encode("utf-8"))
+ m.update(cuda_source.encode("utf-8"))
+ for name, doc in sorted(cpp_functions.items()):
+ m.update(name.encode("utf-8"))
+ m.update(doc.encode("utf-8"))
+ for name, doc in sorted(cuda_functions.items()):
+ m.update(name.encode("utf-8"))
+ m.update(doc.encode("utf-8"))
+ for flag in extra_cflags:
+ m.update(flag.encode("utf-8"))
+ for flag in extra_cuda_cflags:
+ m.update(flag.encode("utf-8"))
+ for flag in extra_ldflags:
+ m.update(flag.encode("utf-8"))
+ for path in extra_include_paths:
+ m.update(path.encode("utf-8"))
+ return m.hexdigest()[:16]
+
+
+def _maybe_write(path: str, content: str) -> None:
+ """Write content to path if it does not already exist with the same
content."""
+ if os.path.exists(path):
+ with open(path, "r") as f:
+ existing_content = f.read()
+ if existing_content == content:
+ return
+ with open(path, "w") as f:
+ f.write(content)
+
+
[email protected]_cache
+def _find_cuda_home() -> Optional[str]:
+ """Find the CUDA install path."""
+ # Guess #1
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+ if cuda_home is None:
+ # Guess #2
+ nvcc_path = shutil.which("nvcc")
+ if nvcc_path is not None:
+ cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
+ else:
+ # Guess #3
+ if IS_WINDOWS:
+ cuda_homes = glob.glob("C:/Program Files/NVIDIA GPU Computing
Toolkit/CUDA/v*.*")
+ if len(cuda_homes) == 0:
+ cuda_home = ""
+ else:
+ cuda_home = cuda_homes[0]
+ else:
+ cuda_home = "/usr/local/cuda"
+ if not os.path.exists(cuda_home):
+ raise RuntimeError(
+ "Could not find CUDA installation. "
+ "Please set CUDA_HOME environment variable."
+ )
+ return cuda_home
+
+
+def _get_cuda_target() -> str:
+ """Get the CUDA target architecture flag."""
+ if "TVM_FFI_CUDA_ARCH_LIST" in os.environ:
+ arch_list = os.environ["TVM_FFI_CUDA_ARCH_LIST"].split() # e.g., "8.9
9.0a"
+ flags = []
+ for arch in arch_list:
+ if len(arch.split(".")) != 2:
+ raise ValueError(f"Invalid CUDA architecture: {arch}")
+ major, minor = arch.split(".")
+
flags.append(f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}")
+ return " ".join(flags)
+ else:
+ #
+ try:
+ status = subprocess.run(
+ args=["nvidia-smi", "--query-gpu=compute_cap",
"--format=csv,noheader"],
+ capture_output=True,
+ check=True,
+ )
+ compute_cap = status.stdout.decode("utf-8").strip().split("\n")[0]
+ major, minor = compute_cap.split(".")
+ return
f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
+ except Exception:
+ # fallback to a reasonable default
+ return "-gencode=arch=compute_70,code=sm_70"
+
+
def _ninja_escape_path(path: str) -> str:
    """Escape ``path`` for a ninja ``build`` line, where ``$``, space and ``:`` are special."""
    return path.replace("$", "$$").replace(" ", "$ ").replace(":", "$:")


def _quote_arg(arg: str) -> str:
    """Wrap a command-line argument in double quotes when it contains whitespace.

    Double quotes are honored by POSIX shells and by Windows command-line parsing.
    Arguments without whitespace are returned unchanged, so typical command
    lines are unaffected.
    """
    if any(ch.isspace() for ch in arg):
        return '"{}"'.format(arg)
    return arg


def _generate_ninja_build(
    name: str,
    build_dir: str,
    with_cuda: bool,
    extra_cflags: Sequence[str],
    extra_cuda_cflags: Sequence[str],
    extra_ldflags: Sequence[str],
    extra_include_paths: Sequence[str],
) -> str:
    """Generate the content of build.ninja for building the module.

    The generated file compiles ``<build_dir>/main.cpp`` (and ``<build_dir>/cuda.cu``
    when ``with_cuda`` is true) into object files and links them into
    ``<name>.so``, which is also the default target.  Include/library/nvcc paths
    containing whitespace (e.g. a toolkit under ``C:/Program Files``) are quoted,
    and paths on ``build`` lines are ninja-escaped.

    Parameters
    ----------
    name: str
        Module name; the output library is ``<name>.so``.
    build_dir: str
        Directory holding the generated sources.
    with_cuda: bool
        Whether to add the nvcc rule, the ``cuda.o`` object and the cudart link flags.
    extra_cflags, extra_cuda_cflags, extra_ldflags: Sequence[str]
        User flags appended (whitespace-stripped) after the defaults.
    extra_include_paths: Sequence[str]
        Additional include directories; made absolute.

    Returns
    -------
    str
        The text of ``build.ninja``.
    """
    default_include_paths = [find_include_path(), find_dlpack_include_path()]

    if IS_WINDOWS:
        default_cflags = ["/std:c++17"]
        default_cuda_cflags = ["-Xcompiler", "/std:c++17", "/O2"]
        default_ldflags = ["/DLL"]
    else:
        default_cflags = ["-std=c++17", "-fPIC", "-O2"]
        default_cuda_cflags = ["-Xcompiler", "-fPIC", "-std=c++17", "-O2"]
        default_ldflags = ["-shared"]

    cuda_home = None
    if with_cuda:
        cuda_home = _find_cuda_home()
        # determine the compute capability of the current GPU
        default_cuda_cflags.append(_get_cuda_target())
        default_ldflags += [_quote_arg("-L{}".format(os.path.join(cuda_home, "lib64"))), "-lcudart"]

    cflags = default_cflags + [flag.strip() for flag in extra_cflags]
    cuda_cflags = default_cuda_cflags + [flag.strip() for flag in extra_cuda_cflags]
    ldflags = default_ldflags + [flag.strip() for flag in extra_ldflags]
    include_paths = default_include_paths + [os.path.abspath(path) for path in extra_include_paths]

    # append include paths
    for path in include_paths:
        include_flag = _quote_arg("-I{}".format(path))
        cflags.append(include_flag)
        cuda_cflags.append(include_flag)

    # CXX is deliberately not quoted: it may hold a launcher plus compiler (e.g. "ccache c++").
    cxx = os.environ.get("CXX", "cl" if IS_WINDOWS else "c++")

    # flags
    ninja = [
        "ninja_required_version = 1.3",
        "cxx = {}".format(cxx),
        "cflags = {}".format(" ".join(cflags)),
    ]
    if with_cuda:
        ninja.append("nvcc = {}".format(_quote_arg(os.path.join(cuda_home, "bin", "nvcc"))))
        ninja.append("cuda_cflags = {}".format(" ".join(cuda_cflags)))
    ninja.append("ldflags = {}".format(" ".join(ldflags)))

    # rules
    # NOTE(review): -MMD/-MF and `deps = gcc` are GCC/Clang conventions; when IS_WINDOWS
    # selects cl.exe these rules likely need MSVC equivalents (/showIncludes, deps = msvc).
    ninja += [
        "",
        "rule compile",
        "  depfile = $out.d",
        "  deps = gcc",
        "  command = $cxx -MMD -MF $out.d $cflags -c $in -o $out",
        "",
    ]
    if with_cuda:
        ninja += [
            "rule compile_cuda",
            "  depfile = $out.d",
            "  deps = gcc",
            "  command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d"
            " $cuda_cflags -c $in -o $out",
            "",
        ]
    ninja += [
        "rule link",
        "  command = $cxx $in $ldflags -o $out",
        "",
    ]

    # build targets; ninja shell-quotes $in/$out itself, so escaping the build lines suffices
    main_src = _ninja_escape_path(os.path.abspath(os.path.join(build_dir, "main.cpp")))
    ninja.append("build main.o: compile {}".format(main_src))
    if with_cuda:
        cuda_src = _ninja_escape_path(os.path.abspath(os.path.join(build_dir, "cuda.cu")))
        ninja.append("build cuda.o: compile_cuda {}".format(cuda_src))
    library = _ninja_escape_path("{}.so".format(name))
    ninja.append("build {}: link main.o{}".format(library, " cuda.o" if with_cuda else ""))
    ninja.append("")

    # default target
    ninja.append("default {}".format(library))
    ninja.append("")
    return "\n".join(ninja)
+
+
def _build_ninja(build_dir: str) -> None:
    """Build the module in the given build directory using ninja.

    The number of parallel jobs can be capped via the ``MAX_JOBS`` environment
    variable (passed to ``ninja -j``).

    Raises
    ------
    RuntimeError
        If ninja exits with a non-zero status; the message includes the captured
        stdout and stderr.
    """
    command = ["ninja", "-v"]
    num_workers = os.environ.get("MAX_JOBS")
    if num_workers is not None:
        command += ["-j", num_workers]
    status = subprocess.run(args=command, cwd=build_dir, capture_output=True)
    if status.returncode == 0:
        return

    msg = ["ninja exited with status {}".format(status.returncode)]
    # Compiler diagnostics are not guaranteed to be UTF-8 (e.g. MSVC in a non-English
    # locale); decode leniently so a UnicodeDecodeError cannot hide the real build error.
    if status.stdout:
        msg.append("stdout:\n{}".format(status.stdout.decode("utf-8", errors="replace")))
    if status.stderr:
        msg.append("stderr:\n{}".format(status.stderr.decode("utf-8", errors="replace")))
    raise RuntimeError("\n".join(msg))
+
+
+def _decorate_with_tvm_ffi(source: str, functions: Mapping[str, str]) -> str:
+ """Decorate the given source code with TVM FFI export macros."""
+ sources = [
+ "#include <tvm/ffi/dtype.h>",
+ "#include <tvm/ffi/error.h>",
+ "#include <tvm/ffi/extra/c_env_api.h>",
+ "#include <tvm/ffi/function.h>",
+ "",
+ source,
+ ]
+
+ for exported_name, func_name_in_source in functions.items():
+ sources.append(f"TVM_FFI_DLL_EXPORT_TYPED_FUNC({exported_name},
{func_name_in_source});")
+ sources.append("")
+
+ return "\n".join(sources)
+
+
def load_inline(
    name: str,
    *,
    cpp_source: Optional[str] = None,
    cuda_source: Optional[str] = None,
    cpp_functions: Optional[Mapping[str, str]] = None,
    cuda_functions: Optional[Mapping[str, str]] = None,
    extra_cflags: Optional[Sequence[str]] = None,
    extra_cuda_cflags: Optional[Sequence[str]] = None,
    extra_ldflags: Optional[Sequence[str]] = None,
    extra_include_paths: Optional[Sequence[str]] = None,
) -> Module:
    """Compile and load a C++/CUDA tvm ffi module from inline source code.

    This function compiles the given C++ and/or CUDA source code into a shared library.
    ``cpp_source`` and ``cuda_source`` are each compiled to an object file, and the
    objects are linked together into a shared library. It's possible to provide only
    ``cpp_source`` or only ``cuda_source``.

    The ``cpp_functions`` and ``cuda_functions`` parameters specify which functions in
    the source code are exported to the tvm ffi module. The keys of each mapping are
    the names of the exported functions, and the values are the names of the functions
    in the source code. The exported name and the function name in the source code
    must be different. The exported name must be a valid C identifier, while the
    function name in the source code can contain namespace qualifiers. An exported
    name may appear in at most one of the two mappings.

    Extra compiler and linker flags can be provided via the ``extra_cflags``,
    ``extra_cuda_cflags``, and ``extra_ldflags`` parameters. The default flags are
    generally sufficient for most use cases, but you may need to provide additional
    flags for your specific use case.

    The include directories of tvm ffi and dlpack are passed to the compiler by
    default, so any header from tvm ffi and dlpack can be included in your source
    code. Additional include paths can be provided via ``extra_include_paths`` to
    include custom headers.

    The compiled shared library is cached to avoid recompilation. The cache directory
    can be specified via the ``TVM_FFI_CACHE_DIR`` environment variable (a leading
    ``~`` is expanded). If not specified, the default cache directory is
    ``~/.cache/tvm-ffi``. Each build lives in ``<cache_dir>/<name>_<hash>``, where the
    hash covers the sources, the function mappings and the extra flags/paths.

    Parameters
    ----------
    name: str
        The name of the tvm ffi module.
    cpp_source: str, optional
        The C++ source code.
    cuda_source: str, optional
        The CUDA source code.
    cpp_functions: Mapping[str, str], optional
        The mapping from the exported function name to the function name in the C++
        source code.
    cuda_functions: Mapping[str, str], optional
        The mapping from the exported function name to the function name in the CUDA
        source code. Must be empty when ``cuda_source`` is empty.
    extra_cflags: Sequence[str], optional
        The extra compiler flags for C++ compilation.
        The default flags are:
        - On Linux/macOS: ['-std=c++17', '-fPIC', '-O2']
        - On Windows: ['/std:c++17']
    extra_cuda_cflags: Sequence[str], optional
        The extra compiler flags for CUDA compilation.
        The default flags are:
        - On Linux/macOS: ['-Xcompiler', '-fPIC', '-std=c++17', '-O2']
        - On Windows: ['-Xcompiler', '/std:c++17', '/O2']
    extra_ldflags: Sequence[str], optional
        The extra linker flags.
        The default flags are:
        - On Linux/macOS: ['-shared']
        - On Windows: ['/DLL']
    extra_include_paths: Sequence[str], optional
        The extra include paths (made absolute).
        The default include paths are:
        - The include path of tvm ffi
        - The include path of dlpack

    Returns
    -------
    mod: Module
        The loaded tvm ffi module.

    Raises
    ------
    ValueError
        If ``cuda_functions`` is non-empty while ``cuda_source`` is empty, or if an
        exported name appears in both ``cpp_functions`` and ``cuda_functions``.
    RuntimeError
        If the CUDA installation cannot be located (CUDA builds) or the build fails.
    """
    cpp_source = cpp_source or ""
    cuda_source = cuda_source or ""
    cpp_functions = dict(cpp_functions or {})
    cuda_functions = dict(cuda_functions or {})
    extra_cflags = list(extra_cflags or [])
    extra_cuda_cflags = list(extra_cuda_cflags or [])
    extra_ldflags = list(extra_ldflags or [])
    extra_include_paths = list(extra_include_paths or [])

    # whether we have cuda source in this module
    with_cuda = len(cuda_source.strip()) > 0

    # Reject configurations that would otherwise be silently dropped (CUDA exports are
    # only compiled when there is CUDA source) or that emit the same export macro name
    # into both object files (likely a duplicate-symbol link error).
    if cuda_functions and not with_cuda:
        raise ValueError(
            "cuda_functions {} were given but cuda_source is empty".format(sorted(cuda_functions))
        )
    duplicated = sorted(set(cpp_functions) & set(cuda_functions))
    if duplicated:
        raise ValueError(
            "Functions exported from both cpp_source and cuda_source: {}".format(duplicated)
        )

    # add function registration code to sources
    cpp_source = _decorate_with_tvm_ffi(cpp_source, cpp_functions)
    cuda_source = _decorate_with_tvm_ffi(cuda_source, cuda_functions)

    # determine the cache dir for the built module
    cache_dir = os.path.expanduser(os.environ.get("TVM_FFI_CACHE_DIR", "~/.cache/tvm-ffi"))
    source_hash: str = _hash_sources(
        cpp_source,
        cuda_source,
        cpp_functions,
        cuda_functions,
        extra_cflags,
        extra_cuda_cflags,
        extra_ldflags,
        extra_include_paths,
    )
    build_dir: str = os.path.join(cache_dir, "{}_{}".format(name, source_hash))
    os.makedirs(build_dir, exist_ok=True)

    # generate build.ninja
    ninja_source = _generate_ninja_build(
        name=name,
        build_dir=build_dir,
        with_cuda=with_cuda,
        extra_cflags=extra_cflags,
        extra_cuda_cflags=extra_cuda_cflags,
        extra_ldflags=extra_ldflags,
        extra_include_paths=extra_include_paths,
    )

    # serialize concurrent builds of the same configuration (other threads/processes)
    with FileLock(os.path.join(build_dir, "lock")):
        # write source files and build.ninja only if their content changed
        _maybe_write(os.path.join(build_dir, "main.cpp"), cpp_source)
        if with_cuda:
            _maybe_write(os.path.join(build_dir, "cuda.cu"), cuda_source)
        _maybe_write(os.path.join(build_dir, "build.ninja"), ninja_source)

        # build the module
        _build_ninja(build_dir)

    return load_module(os.path.join(build_dir, "{}.so".format(name)))
diff --git a/ffi/python/tvm_ffi/utils/__init__.py b/ffi/python/tvm_ffi/utils/__init__.py
new file mode 100644
index 0000000000..543bd0f841
--- /dev/null
+++ b/ffi/python/tvm_ffi/utils/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from .lockfile import FileLock
diff --git a/ffi/python/tvm_ffi/utils/lockfile.py b/ffi/python/tvm_ffi/utils/lockfile.py
new file mode 100644
index 0000000000..3b3197e2d8
--- /dev/null
+++ b/ffi/python/tvm_ffi/utils/lockfile.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+import time
+
+# Platform-specific imports for file locking
+if sys.platform == "win32":
+ import msvcrt
+else:
+ import fcntl
+
+
class FileLock:
    """
    A cross-platform file locking mechanism using Python's standard library.
    This class implements an advisory lock, which must be respected by all
    cooperating processes.

    The lock is taken on a descriptor of ``lock_file_path`` (created if missing):
    ``fcntl.flock`` on Unix-like systems, ``msvcrt.locking`` on Windows.  When used
    as a context manager it blocks until the lock is acquired and releases it on
    exit.  Calling :meth:`acquire` on an instance that already holds the lock
    returns True without taking a second descriptor.
    """

    def __init__(self, lock_file_path):
        self.lock_file_path = lock_file_path
        # Descriptor of the open lock file while this instance holds the lock, else None.
        self._file_descriptor = None

    def __enter__(self):
        """
        Context manager protocol: acquire the lock upon entering the 'with' block.
        This method will block indefinitely until the lock is acquired.
        """
        self.blocking_acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Context manager protocol: release the lock upon exiting the 'with' block.
        """
        self.release()
        return False  # Propagate exceptions, if any

    def _open_lock_file(self):
        """
        Open (creating if needed) the lock file and return its descriptor.

        Raises RuntimeError if the file cannot be opened (e.g. the parent directory
        is missing or not writable). That is not lock contention, so retrying would
        never succeed and must not be reported as "lock busy".
        """
        if sys.platform == "win32":
            flags = os.O_RDWR | os.O_CREAT | os.O_BINARY
        else:  # Unix-like systems
            flags = os.O_WRONLY | os.O_CREAT
        try:
            return os.open(self.lock_file_path, flags)
        except OSError as e:
            raise RuntimeError(f"Failed to open lock file '{self.lock_file_path}': {e}") from e

    def acquire(self):
        """
        Acquires an exclusive, non-blocking lock on the file.
        Returns True if the lock was acquired (or is already held by this instance),
        False if it is currently held elsewhere.

        Raises RuntimeError on errors other than lock contention.
        """
        if self._file_descriptor is not None:
            # Already held; opening a second descriptor would conflict with our own
            # lock and leak the first descriptor.
            return True
        fd = self._open_lock_file()
        # Contention shows up as BlockingIOError from flock(LOCK_NB) and as a plain
        # OSError from msvcrt.locking(LK_NBLCK).
        contention = OSError if sys.platform == "win32" else BlockingIOError
        try:
            if sys.platform == "win32":
                msvcrt.locking(fd, msvcrt.LK_NBLCK, 1)
            else:
                fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except contention:
            os.close(fd)
            return False
        except Exception as e:
            os.close(fd)
            raise RuntimeError(f"An unexpected error occurred: {e}") from e
        self._file_descriptor = fd
        return True

    def blocking_acquire(self, timeout=None, poll_interval=0.1):
        """
        Waits until an exclusive lock can be acquired, with an optional timeout.

        Args:
            timeout (float): The maximum time to wait for the lock in seconds.
                A value of None means wait indefinitely.
            poll_interval (float): The time to wait between lock attempts in seconds.

        Returns:
            True once the lock is held.

        Raises:
            TimeoutError: if ``timeout`` elapses before the lock is acquired.
        """
        # monotonic clock: unaffected by wall-clock adjustments while waiting
        start_time = time.monotonic()
        while True:
            if self.acquire():
                return True

            # Check for timeout
            if timeout is not None and (time.monotonic() - start_time) > timeout:
                raise TimeoutError(
                    f"Failed to acquire lock on '{self.lock_file_path}' after {timeout} seconds."
                )

            time.sleep(poll_interval)

    def release(self):
        """
        Releases the lock and closes the file descriptor.
        Does nothing if this instance does not hold the lock.
        """
        fd = self._file_descriptor
        if fd is None:
            return
        self._file_descriptor = None
        try:
            if sys.platform == "win32":
                msvcrt.locking(fd, msvcrt.LK_UNLCK, 1)
            else:
                fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            # close even if unlocking failed so the descriptor is never leaked
            os.close(fd)
diff --git a/ffi/tests/python/test_load_inline.py b/ffi/tests/python/test_load_inline.py
new file mode 100644
index 0000000000..bb14ae9792
--- /dev/null
+++ b/ffi/tests/python/test_load_inline.py
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import numpy
+
+try:
+ import torch
+except ImportError:
+ torch = None
+
+import tvm_ffi.cpp
+from tvm_ffi.module import Module
+
+
def test_load_inline_cpp():
    """A C++-only inline module exports ``add_one_cpu`` that works on numpy arrays."""
    cpp_code = r"""
        void AddOne(DLTensor* x, DLTensor* y) {
          // implementation of a library function
          TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
          DLDataType f32_dtype{kDLFloat, 32, 1};
          TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
          TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
          TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
          TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
          for (int i = 0; i < x->shape[0]; ++i) {
            static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
          }
        }
    """
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cpp_source=cpp_code,
        cpp_functions={"add_one_cpu": "AddOne"},
    )

    src = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
    dst = numpy.empty_like(src)
    mod.add_one_cpu(src, dst)
    numpy.testing.assert_equal(src + 1, dst)
+
+
@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available(),
    reason="Requires CUDA",
)
def test_load_inline_cuda():
    """A CUDA-only inline module launches its kernel on the current torch stream.

    Skipped unless torch is installed and reports an available CUDA device, so
    the test runs on GPU machines instead of being skipped unconditionally.
    """
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cuda_source=r"""
            __global__ void AddOneKernel(float* x, float* y, int n) {
              int idx = blockIdx.x * blockDim.x + threadIdx.x;
              if (idx < n) {
                y[idx] = x[idx] + 1;
              }
            }

            void AddOneCUDA(DLTensor* x, DLTensor* y) {
              // implementation of a library function
              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
              DLDataType f32_dtype{kDLFloat, 32, 1};
              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";

              int64_t n = x->shape[0];
              int64_t nthread_per_block = 256;
              int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
              // Obtain the current stream from the environment
              // it will be set to torch.cuda.current_stream() when calling the function
              // with torch.Tensors
              cudaStream_t stream = static_cast<cudaStream_t>(
                  TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
              // launch the kernel
              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                                     static_cast<float*>(y->data), n);
            }
        """,
        cuda_functions={"add_one_cuda": "AddOneCUDA"},
    )

    x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
    y_cuda = torch.empty_like(x_cuda)
    mod.add_one_cuda(x_cuda, y_cuda)
    torch.testing.assert_close(x_cuda + 1, y_cuda)
+
+
@pytest.mark.skipif(
    torch is None or not torch.cuda.is_available(),
    reason="Requires CUDA",
)
def test_load_inline_both():
    """A mixed C++/CUDA inline module exports working CPU and CUDA functions.

    Skipped unless torch is installed and reports an available CUDA device
    (building the CUDA half needs the toolkit), so the test is no longer
    skipped unconditionally.
    """
    mod: Module = tvm_ffi.cpp.load_inline(
        name="hello",
        cpp_source=r"""
            void AddOne(DLTensor* x, DLTensor* y) {
              // implementation of a library function
              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
              DLDataType f32_dtype{kDLFloat, 32, 1};
              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
              for (int i = 0; i < x->shape[0]; ++i) {
                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
              }
            }
        """,
        cuda_source=r"""
            __global__ void AddOneKernel(float* x, float* y, int n) {
              int idx = blockIdx.x * blockDim.x + threadIdx.x;
              if (idx < n) {
                y[idx] = x[idx] + 1;
              }
            }

            void AddOneCUDA(DLTensor* x, DLTensor* y) {
              // implementation of a library function
              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
              DLDataType f32_dtype{kDLFloat, 32, 1};
              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";

              int64_t n = x->shape[0];
              int64_t nthread_per_block = 256;
              int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
              // Obtain the current stream from the environment
              // it will be set to torch.cuda.current_stream() when calling the function
              // with torch.Tensors
              cudaStream_t stream = static_cast<cudaStream_t>(
                  TVMFFIEnvGetCurrentStream(x->device.device_type, x->device.device_id));
              // launch the kernel
              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
                                                                     static_cast<float*>(y->data), n);
            }
        """,
        cpp_functions={"add_one_cpu": "AddOne"},
        cuda_functions={"add_one_cuda": "AddOneCUDA"},
    )

    x = numpy.array([1, 2, 3, 4, 5], dtype=numpy.float32)
    y = numpy.empty_like(x)
    mod.add_one_cpu(x, y)
    numpy.testing.assert_equal(x + 1, y)

    x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
    y_cuda = torch.empty_like(x_cuda)
    mod.add_one_cuda(x_cuda, y_cuda)
    torch.testing.assert_close(x_cuda + 1, y_cuda)