This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 1af6d9f feat(stubgen): Refactor into staged pipeline, Introduce
directive `import` (#259)
1af6d9f is described below
commit 1af6d9f9648bb2547d97d455bb4464d217654fc4
Author: Junru Shao <[email protected]>
AuthorDate: Sat Nov 15 09:38:28 2025 -0800
feat(stubgen): Refactor into staged pipeline, Introduce directive `import`
(#259)
This PR rebuilds `tvm-ffi-stubgen` into a staged pipeline:
- parsing now lives in `python/tvm_ffi/stub/file_utils.py`,
- directive/format constants moved to `python/tvm_ffi/stub/consts.py`,
- signature/field generation lives in `python/tvm_ffi/stub/codegen.py`,
and
- the CLI (now `python/tvm_ffi/stub/cli.py`, replacing
`python/tvm_ffi/stub/stubgen.py`) scans whole trees,
honors file-level `ty-map` directives, can run in `--dry-run` /
`--verbose` mode, and auto-infers which imports must be emitted inside a
new `tvm-ffi-stubgen(begin): import` block.
The new directives are adopted across the public Python surface:
`_ffi_api.py`, `access_path.py`, `module.py`, and `testing.py`. They all
- gained import blocks
- switched to the `ty-map` syntax, and
- let the generator emit PEP 526 attribute annotations outside of
`TYPE_CHECKING` while keeping method shims guarded.
A new directive `import` is introduced to automatically import the types
needed by generated stubs. Once it is placed in a file via:
```python
# tvm-ffi-stubgen(begin): import
# tvm-ffi-stubgen(end): import
```
the stubgen tool will expand it into:
```diff
# tvm-ffi-stubgen(begin): import
+ # fmt: off
+ # isort: off
+ from __future__ import annotations
+ from typing import Any, Callable, TYPE_CHECKING
+ if TYPE_CHECKING:
+ from collections.abc import Mapping, Sequence
+ from tvm_ffi import Module
+ from tvm_ffi.access_path import AccessPath
+ # isort: on
+ # fmt: on
# tvm-ffi-stubgen(end): import
```
---
pyproject.toml | 6 +-
python/tvm_ffi/_ffi_api.py | 18 +-
python/tvm_ffi/access_path.py | 54 ++--
python/tvm_ffi/module.py | 21 +-
python/tvm_ffi/stub/analysis.py | 62 +++++
python/tvm_ffi/stub/cli.py | 196 ++++++++++++++
python/tvm_ffi/stub/codegen.py | 183 +++++++++++++
python/tvm_ffi/stub/consts.py | 56 ++++
python/tvm_ffi/stub/file_utils.py | 240 +++++++++++++++++
python/tvm_ffi/stub/stubgen.py | 528 --------------------------------------
python/tvm_ffi/stub/utils.py | 32 +++
python/tvm_ffi/testing.py | 93 +++----
tests/python/test_stubgen.py | 182 -------------
13 files changed, 877 insertions(+), 794 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 3630002..95c4f44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,7 +100,7 @@ docs = [
[project.scripts]
tvm-ffi-config = "tvm_ffi.config:__main__"
-tvm-ffi-stubgen = "tvm_ffi.stub.stubgen:__main__"
+tvm-ffi-stubgen = "tvm_ffi.stub.cli:__main__"
[build-system]
requires = ["scikit-build-core>=0.10.0", "cython>=3.0", "setuptools-scm"]
@@ -274,6 +274,10 @@ exclude = '''(?x)(
module = ["torch", "torch.*", "my_ffi_extension", "my_ffi_extension.*"]
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = "_pytest.*"
+follow_imports = "skip"
+
[tool.uv.dependency-groups]
docs = { requires-python = ">=3.13" }
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 316a6fb..1bdef7a 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -16,22 +16,24 @@
# under the License.
"""FFI API."""
+# tvm-ffi-stubgen(begin): import
+# fmt: off
+# isort: off
from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Callable
-
-from . import registry
-
+from typing import Any, Callable, TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
-
from tvm_ffi import Module
from tvm_ffi.access_path import AccessPath
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
+from . import registry
# tvm-ffi-stubgen(begin): global/ffi
+# fmt: off
if TYPE_CHECKING:
- # fmt: off
def Array(*args: Any) -> Any: ...
def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
def ArraySize(_0: Sequence[Any], /) -> int: ...
@@ -67,7 +69,7 @@ if TYPE_CHECKING:
def SystemLib(*args: Any) -> Any: ...
def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...
def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
- # fmt: on
+# fmt: on
# tvm-ffi-stubgen(end)
registry.init_ffi_api("ffi", __name__)
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index b69e05d..89d1b3c 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -17,11 +17,18 @@
# pylint: disable=invalid-name
"""Access path classes."""
+# tvm-ffi-stubgen(begin): import
+# fmt: off
+# isort: off
from __future__ import annotations
-
-from collections.abc import Sequence
+from typing import Any, TYPE_CHECKING
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+ from tvm_ffi import Object
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
from enum import IntEnum
-from typing import TYPE_CHECKING, Any
from .core import Object
from .registry import register_object
@@ -42,12 +49,12 @@ class AccessKind(IntEnum):
class AccessStep(Object):
"""Access step container."""
+ # tvm-ffi-stubgen(ty-map): ffi.reflection.AccessStep ->
ffi.access_path.AccessStep
# tvm-ffi-stubgen(begin): object/ffi.reflection.AccessStep
- if TYPE_CHECKING:
- # fmt: off
- kind: int
- key: Any
- # fmt: on
+ # fmt: off
+ kind: int
+ key: Any
+ # fmt: on
# tvm-ffi-stubgen(end)
@@ -73,25 +80,26 @@ class AccessPath(Object):
"""
+ # tvm-ffi-stubgen(ty-map): ffi.reflection.AccessPath ->
ffi.access_path.AccessPath
# tvm-ffi-stubgen(begin): object/ffi.reflection.AccessPath
+ # fmt: off
+ parent: Object | None
+ step: AccessStep | None
+ depth: int
if TYPE_CHECKING:
- # fmt: off
- parent: Object | None
- step: AccessStep | None
- depth: int
@staticmethod
def _root() -> AccessPath: ...
- def _extend(_0: AccessPath, _1: AccessStep, /) -> AccessPath: ...
- def _attr(_0: AccessPath, _1: str, /) -> AccessPath: ...
- def _array_item(_0: AccessPath, _1: int, /) -> AccessPath: ...
- def _map_item(_0: AccessPath, _1: Any, /) -> AccessPath: ...
- def _attr_missing(_0: AccessPath, _1: str, /) -> AccessPath: ...
- def _array_item_missing(_0: AccessPath, _1: int, /) -> AccessPath: ...
- def _map_item_missing(_0: AccessPath, _1: Any, /) -> AccessPath: ...
- def _is_prefix_of(_0: AccessPath, _1: AccessPath, /) -> bool: ...
- def _to_steps(_0: AccessPath, /) -> Sequence[AccessStep]: ...
- def _path_equal(_0: AccessPath, _1: AccessPath, /) -> bool: ...
- # fmt: on
+ def _extend(self, _1: AccessStep, /) -> AccessPath: ...
+ def _attr(self, _1: str, /) -> AccessPath: ...
+ def _array_item(self, _1: int, /) -> AccessPath: ...
+ def _map_item(self, _1: Any, /) -> AccessPath: ...
+ def _attr_missing(self, _1: str, /) -> AccessPath: ...
+ def _array_item_missing(self, _1: int, /) -> AccessPath: ...
+ def _map_item_missing(self, _1: Any, /) -> AccessPath: ...
+ def _is_prefix_of(self, _1: AccessPath, /) -> bool: ...
+ def _to_steps(self, /) -> Sequence[AccessStep]: ...
+ def _path_equal(self, _1: AccessPath, /) -> bool: ...
+ # fmt: on
# tvm-ffi-stubgen(end)
def __init__(self) -> None:
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index 73970e0..080d431 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -16,13 +16,19 @@
# under the License.
"""Module related objects and functions."""
-# pylint: disable=invalid-name
+# tvm-ffi-stubgen(begin): import
+# fmt: off
+# isort: off
from __future__ import annotations
-
-from collections.abc import Sequence
+from typing import Any, TYPE_CHECKING
+if TYPE_CHECKING:
+ from collections.abc import Sequence
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
from enum import IntEnum
from os import PathLike, fspath
-from typing import TYPE_CHECKING, Any, ClassVar, cast
+from typing import ClassVar, cast
from . import _ffi_api, core
from .registry import register_object
@@ -87,10 +93,9 @@ class Module(core.Object):
"""
# tvm-ffi-stubgen(begin): object/ffi.Module
- if TYPE_CHECKING:
- # fmt: off
- imports_: Sequence[Any]
- # fmt: on
+ # fmt: off
+ imports_: Sequence[Any]
+ # fmt: on
# tvm-ffi-stubgen(end)
entry_name: ClassVar[str] = "main" # constant for entry function name
diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py
new file mode 100644
index 0000000..2207397
--- /dev/null
+++ b/python/tvm_ffi/stub/analysis.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Analysis utilities."""
+
+from __future__ import annotations
+
+from tvm_ffi.registry import list_global_func_names
+
+from . import consts as C
+from .file_utils import FileInfo
+from .utils import Options
+
+
+def collect_global_funcs() -> dict[str, list[str]]:
+ """Collect global functions from TVM FFI's global registry."""
+ # Build global function table only if we are going to process blocks.
+ global_funcs: dict[str, list[str]] = {}
+ for name in list_global_func_names():
+ try:
+ prefix, suffix = name.rsplit(".", 1)
+ except ValueError:
+ print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function:
{name}{C.TERM_RESET}")
+ else:
+ global_funcs.setdefault(prefix, []).append(suffix)
+ # Ensure stable ordering for deterministic output.
+ for k in list(global_funcs.keys()):
+ global_funcs[k].sort()
+ return global_funcs
+
+
+def collect_ty_maps(files: list[FileInfo], opt: Options) -> dict[str, str]:
+ """Collect type maps from the given files."""
+ ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
+ for file in files:
+ for code in file.code_blocks:
+ if code.kind == "ty-map":
+ try:
+ lhs, rhs = code.param.split("->")
+ except ValueError as e:
+ raise ValueError(
+ f"Invalid ty_map format at line {code.lineno_start}.
Example: `A.B -> C.D`"
+ ) from e
+ ty_map[lhs.strip()] = rhs.strip()
+ if opt.verbose:
+ for lhs in sorted(ty_map):
+ rhs = ty_map[lhs]
+ print(f"{C.TERM_CYAN}[TY-MAP] {lhs} -> {rhs}{C.TERM_RESET}")
+ return ty_map
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
new file mode 100644
index 0000000..5728e12
--- /dev/null
+++ b/python/tvm_ffi/stub/cli.py
@@ -0,0 +1,196 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# tvm-ffi-stubgen(skip-file)
+"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``)."""
+
+from __future__ import annotations
+
+import argparse
+import ctypes
+import sys
+import traceback
+from pathlib import Path
+from typing import Callable
+
+from . import codegen as G
+from . import consts as C
+from .analysis import collect_global_funcs, collect_ty_maps
+from .file_utils import FileInfo, collect_files
+from .utils import Options
+
+
+def _fn_ty_map(ty_map: dict[str, str], ty_used: set[str]) -> Callable[[str],
str]:
+ def _run(name: str) -> str:
+ nonlocal ty_map, ty_used
+ if (ret := ty_map.get(name)) is not None:
+ name = ret
+ if (ret := C.TY_TO_IMPORT.get(name)) is not None:
+ name = ret
+ if "." in name:
+ ty_used.add(name)
+ return name.rsplit(".", 1)[-1]
+
+ return _run
+
+
+def __main__() -> int:
+ """Command line entry point for ``tvm-ffi-stubgen``.
+
+ This generates in-place type stubs inside special ``tvm-ffi-stubgen``
blocks
+ in the given files or directories. See the module docstring for an
+ overview and examples of the block syntax.
+ """
+ opt = _parse_args()
+ dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
+ global_funcs = collect_global_funcs()
+ files: list[FileInfo] = collect_files([Path(f) for f in opt.files])
+
+ # Stage 1: Process `tvm-ffi-stubgen(ty-map)`
+ ty_map: dict[str, str] = collect_ty_maps(files, opt)
+
+ # Stage 2: Process
+ # - `tvm-ffi-stubgen(begin): global/...`
+ # - `tvm-ffi-stubgen(begin): object/...`
+
+ def _stage_2(file: FileInfo) -> None:
+ if opt.verbose:
+ print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
+ ty_used: set[str] = set()
+ ty_on_file: set[str] = set()
+ fn_ty_map_fn = _fn_ty_map(ty_map, ty_used)
+ # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
+ for code in file.code_blocks:
+ if code.kind == "global":
+ G.generate_global_funcs(code, global_funcs, fn_ty_map_fn, opt)
+ # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
+ for code in file.code_blocks:
+ if code.kind == "object":
+ G.generate_object(code, fn_ty_map_fn, opt)
+ ty_on_file.add(ty_map.get(code.param, code.param))
+ # Stage 2.3. Add imports for used types.
+ for code in file.code_blocks:
+ if code.kind == "import":
+ G.generate_imports(code, ty_used - ty_on_file, opt)
+ break # Only one import block per file is supported for now.
+ file.update(show_diff=opt.verbose, dry_run=opt.dry_run)
+
+ for file in files:
+ try:
+ _stage_2(file)
+ except:
+ print(
+ f'{C.TERM_RED}[Failed] File "{file.path}":
{traceback.format_exc()}{C.TERM_RESET}'
+ )
+ del dlls
+ return 0
+
+
+def _parse_args() -> Options:
+ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter,
argparse.RawTextHelpFormatter):
+ pass
+
+ parser = argparse.ArgumentParser(
+ prog="tvm-ffi-stubgen",
+ description=(
+ "Generate in-place type stubs for TVM FFI.\n\n"
+ "It scans .py/.pyi files for tvm-ffi-stubgen blocks and fills them
with\n"
+ "TYPE_CHECKING-only annotations derived from TVM runtime metadata."
+ ),
+ formatter_class=HelpFormatter,
+ epilog=(
+ "Examples:\n"
+ " # Single file\n"
+ " tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py\n\n"
+ " # Recursively scan directories\n"
+ " tvm-ffi-stubgen python/tvm_ffi
examples/packaging/python/my_ffi_extension\n\n"
+ " # Preload TVM runtime / extension libraries\n"
+ " tvm-ffi-stubgen --dlls build/libtvm_runtime.so
build/libmy_ext.so my_pkg/_ffi_api.py\n\n"
+ "Stub block syntax (placed in your source):\n"
+ " # tvm-ffi-stubgen(begin): global/<registry-prefix>\n"
+ " ... generated function stubs ...\n"
+ " # tvm-ffi-stubgen(end)\n\n"
+ " # tvm-ffi-stubgen(begin): object/<type_key>\n"
+ " # tvm-ffi-stubgen(ty_map): list -> Sequence\n"
+ " # tvm-ffi-stubgen(ty_map): dict -> Mapping\n"
+ " ... generated fields and methods ...\n"
+ " # tvm-ffi-stubgen(end)\n\n"
+ " # Skip a file entirely\n"
+ " # tvm-ffi-stubgen(skip-file)\n\n"
+ "Tips:\n"
+ " - Only .py/.pyi files are updated; directories are scanned
recursively.\n"
+ " - Import any aliases you use in ty_map under TYPE_CHECKING,
e.g.\n"
+ " from collections.abc import Mapping, Sequence\n"
+ " - Use --dlls to preload shared libraries when function/type
metadata\n"
+ " is provided by native extensions.\n"
+ ),
+ )
+ parser.add_argument(
+ "--dlls",
+ nargs="*",
+ metavar="LIB",
+ help=(
+ "Shared libraries to preload before generation (e.g. TVM runtime
or "
+ "your extension). This ensures global function and object metadata
"
+ "is available. Accepts multiple paths; platform-specific suffixes "
+ "like .so/.dylib/.dll are supported."
+ ),
+ default=[],
+ )
+ parser.add_argument(
+ "--indent",
+ type=int,
+ default=4,
+ help=(
+ "Extra spaces added inside each generated block, relative to the "
+ "indentation of the corresponding '# tvm-ffi-stubgen(begin):'
line."
+ ),
+ )
+ parser.add_argument(
+ "files",
+ nargs="*",
+ metavar="PATH",
+ help=(
+ "Files or directories to process. Directories are scanned
recursively; "
+ "only .py and .pyi files are modified. Use tvm-ffi-stubgen markers
to "
+ "select where stubs are generated."
+ ),
+ )
+ parser.add_argument(
+ "--verbose",
+ action="store_true",
+ help=(
+ "Print a unified diff of changes to each file. This is useful for "
+ "debugging or previewing changes before applying them."
+ ),
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ help=(
+ "Don't write changes to files. This is useful for previewing
changes "
+ "without modifying any files."
+ ),
+ )
+ opt = Options(**vars(parser.parse_args()))
+ if not opt.files:
+ parser.print_help()
+ sys.exit(1)
+ return opt
+
+
+if __name__ == "__main__":
+ sys.exit(__main__())
diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py
new file mode 100644
index 0000000..be33f16
--- /dev/null
+++ b/python/tvm_ffi/stub/codegen.py
@@ -0,0 +1,183 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Code generation logic for the `tvm-ffi-stubgen` tool."""
+
+from __future__ import annotations
+
+from io import StringIO
+from typing import Callable
+
+from tvm_ffi.core import TypeSchema,
_lookup_or_register_type_info_from_type_key
+from tvm_ffi.registry import get_global_func_metadata
+
+from . import consts as C
+from .file_utils import CodeBlock
+from .utils import Options
+
+
+def generate_func_signature(
+ schema: TypeSchema,
+ func_name: str,
+ ty_map: Callable[[str], str],
+ is_member: bool,
+) -> str:
+ """Generate a function signature string from a TypeSchema."""
+ buf = StringIO()
+ buf.write(f"def {func_name}(")
+ if schema.origin != "Callable":
+ raise ValueError(f"Expected Callable type schema, but got: {schema}")
+ if not schema.args:
+ ty_map("Any")
+ buf.write("*args: Any) -> Any: ...")
+ return buf.getvalue()
+ arg_ret = schema.args[0]
+ arg_args = schema.args[1:]
+ for i, arg in enumerate(arg_args):
+ if is_member and i == 0:
+ buf.write("self, ")
+ else:
+ buf.write(f"_{i}: ")
+ buf.write(arg.repr(ty_map))
+ buf.write(", ")
+ if arg_args:
+ buf.write("/")
+ buf.write(") -> ")
+ buf.write(arg_ret.repr(ty_map))
+ buf.write(": ...")
+ return buf.getvalue()
+
+
+def generate_global_funcs(
+ code: CodeBlock,
+ global_funcs: dict[str, list[str]],
+ fn_ty_map: Callable[[str], str],
+ opt: Options,
+) -> None:
+ """Generate function signatures for global functions."""
+ assert len(code.lines) >= 2
+ indent = " " * code.indent
+ indent_long = " " * (code.indent + opt.indent)
+ prefix = code.param
+ results: list[str] = [
+ generate_func_signature(
+
TypeSchema.from_json_str(get_global_func_metadata(f"{prefix}.{name}")["type_schema"]),
+ name,
+ ty_map=fn_ty_map,
+ is_member=False,
+ )
+ for name in global_funcs.get(prefix, [])
+ ]
+ if not results:
+ return
+ code.lines = [
+ code.lines[0],
+ f"{indent}# fmt: off",
+ f"{indent}if TYPE_CHECKING:",
+ *[indent_long + sig for sig in results],
+ f"{indent}# fmt: on",
+ code.lines[-1],
+ ]
+
+
+def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt:
Options) -> None:
+ """Generate a class definition for an object type."""
+ assert len(code.lines) >= 2
+ type_key = code.param
+ type_info = _lookup_or_register_type_info_from_type_key(type_key)
+ indent = " " * code.indent
+ indent_long = " " * (code.indent + opt.indent)
+
+ fields: list[str] = []
+ for field in type_info.fields:
+ fields.append(
+ f"{indent}{field.name}: "
+ +
TypeSchema.from_json_str(field.metadata["type_schema"]).repr(fn_ty_map)
+ )
+
+ methods: list[str] = []
+ if type_info.methods:
+ methods = [f"{indent}if TYPE_CHECKING:"]
+ for method in type_info.methods:
+ if method.is_static:
+ methods.append(f"{indent_long}@staticmethod")
+ methods.append(
+ indent_long
+ + generate_func_signature(
+ TypeSchema.from_json_str(method.metadata["type_schema"]),
+ {
+ "__ffi_init__": "__c_ffi_init__",
+ }.get(method.name, method.name),
+ fn_ty_map,
+ is_member=not method.is_static,
+ )
+ )
+ code.lines = [
+ code.lines[0],
+ f"{indent}# fmt: off",
+ *fields,
+ *methods,
+ f"{indent}# fmt: on",
+ code.lines[-1],
+ ]
+
+
+def generate_imports(
+ code: CodeBlock,
+ ty_used: set[str],
+ opt: Options,
+) -> None:
+ """Generate import statements for the types used in the stub."""
+ ty_collected: dict[str, list[str]] = {}
+ for ty in ty_used:
+ assert "." in ty
+ module, name = ty.rsplit(".", 1)
+ for mod_prefix, mod_replacement in C.MOD_MAP.items():
+ if module.startswith(mod_prefix):
+ module = module.replace(mod_prefix, mod_replacement, 1)
+ break
+ ty_collected.setdefault(module, []).append(name)
+ if not ty_collected:
+ return
+
+ def _make_line(module: str, names: list[str], indent: int) -> str:
+ names = ", ".join(sorted(set(names)))
+ indent_str = " " * indent
+ return f"{indent_str}from {module} import {names}"
+
+ results: list[str] = [
+ "from __future__ import annotations",
+ _make_line(
+ "typing",
+ [*ty_collected.pop("typing", []), "TYPE_CHECKING"],
+ indent=0,
+ ),
+ ]
+ if ty_collected:
+ results.append("if TYPE_CHECKING:")
+ for module in sorted(ty_collected):
+ names = ty_collected[module]
+ results.append(_make_line(module, names, indent=opt.indent))
+ if results:
+ code.lines = [
+ code.lines[0],
+ "# fmt: off",
+ "# isort: off",
+ *results,
+ "# isort: on",
+ "# fmt: on",
+ code.lines[-1],
+ ]
diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py
new file mode 100644
index 0000000..6443c95
--- /dev/null
+++ b/python/tvm_ffi/stub/consts.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constants used in stub generation."""
+
+STUB_PREFIX = "# tvm-ffi-stubgen("
+STUB_BEGIN = f"{STUB_PREFIX}begin):"
+STUB_END = f"{STUB_PREFIX}end)"
+STUB_TY_MAP = f"{STUB_PREFIX}ty-map):"
+STUB_SKIP_FILE = f"{STUB_PREFIX}skip-file)"
+
+TERM_RESET = "\033[0m"
+TERM_BOLD = "\033[1m"
+TERM_BLACK = "\033[30m"
+TERM_RED = "\033[31m"
+TERM_GREEN = "\033[32m"
+TERM_YELLOW = "\033[33m"
+TERM_BLUE = "\033[34m"
+TERM_MAGENTA = "\033[35m"
+TERM_CYAN = "\033[36m"
+TERM_WHITE = "\033[37m"
+
+DEFAULT_SOURCE_EXTS = {".py", ".pyi"}
+TY_MAP_DEFAULTS = {
+ "list": "collections.abc.Sequence",
+ "dict": "collections.abc.Mapping",
+}
+
+TY_TO_IMPORT = {
+ "Any": "typing.Any",
+ "Callable": "typing.Callable",
+ "Mapping": "typing.Mapping",
+ "Object": "tvm_ffi.Object",
+ "Tensor": "tvm_ffi.Tensor",
+ "dtype": "tvm_ffi.dtype",
+ "Device": "tvm_ffi.Device",
+}
+
+# TODO(@junrushao): Make it configurable
+MOD_MAP = {
+ "testing": "tvm_ffi.testing",
+ "ffi": "tvm_ffi",
+}
diff --git a/python/tvm_ffi/stub/file_utils.py
b/python/tvm_ffi/stub/file_utils.py
new file mode 100644
index 0000000..6f95025
--- /dev/null
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for parsing and generating stub files for TVM FFI."""
+
+from __future__ import annotations
+
+import dataclasses
+import difflib
+import traceback
+from pathlib import Path
+from typing import Callable, Generator, Iterable, Literal
+
+from . import consts as C
+
+
[email protected]
+class CodeBlock:
+ """A block of code to be generated in a stub file."""
+
+ kind: Literal["global", "object", "ty-map", "import", None]
+ param: str
+ lineno_start: int
+ lineno_end: int | None
+ lines: list[str]
+
+ def __post_init__(self) -> None:
+ """Validate the code block after initialization."""
+ assert self.kind in {"global", "object", "ty-map", "import", None}
+
+ @property
+ def indent(self) -> int:
+ """Calculate the indentation level of the block based on the first
line."""
+ if not self.lines:
+ return 0
+ first_line = self.lines[0]
+ return len(first_line) - len(first_line.lstrip(" "))
+
+ @staticmethod
+ def from_begin_line(lineo: int, line: str) -> CodeBlock:
+ """Parse a line to create a CodeBlock if it contains a stub begin
marker."""
+ if line.startswith(C.STUB_TY_MAP):
+ return CodeBlock(
+ kind="ty-map",
+ param=line[len(C.STUB_TY_MAP) :].strip(),
+ lineno_start=lineo,
+ lineno_end=lineo,
+ lines=[],
+ )
+ assert line.startswith(C.STUB_BEGIN)
+ stub = line[len(C.STUB_BEGIN) :].strip()
+ if stub.startswith("global/"):
+ kind = "global"
+ param = stub[len("global/") :].strip()
+ elif stub.startswith("object/"):
+ kind = "object"
+ param = stub[len("object/") :].strip()
+ elif stub.startswith("ty-map/"):
+ kind = "ty-map"
+ param = stub[len("ty-map/") :].strip()
+ elif stub.startswith("import"):
+ kind = "import"
+ param = ""
+ else:
+ raise ValueError(f"Unknown stub type `{stub}` at line {lineo}")
+ return CodeBlock(
+ kind=kind, # type: ignore[arg-type]
+ param=param,
+ lineno_start=lineo,
+ lineno_end=None,
+ lines=[],
+ )
+
+
[email protected]
+class FileInfo:
+ """Information about a file being processed."""
+
+ path: Path
+ lines: tuple[str, ...]
+ code_blocks: list[CodeBlock]
+
+ def update(self, show_diff: bool, dry_run: bool) -> bool:
+ """Update the file's lines based on the current code blocks and
optionally show a diff."""
+ new_lines = tuple(line for block in self.code_blocks for line in
block.lines)
+ if self.lines == new_lines:
+ return False
+ if show_diff:
+ for line in difflib.unified_diff(self.lines, new_lines,
lineterm=""):
+ # Skip placeholder headers when fromfile/tofile are unspecified
+ if line.startswith("---") or line.startswith("+++"):
+ continue
+ if line.startswith("-") and not line.startswith("---"):
+ print(f"{C.TERM_RED}{line}{C.TERM_RESET}") # Red for
removals
+ elif line.startswith("+") and not line.startswith("+++"):
+ print(f"{C.TERM_GREEN}{line}{C.TERM_RESET}") # Green for
additions
+ elif line.startswith("?"):
+ print(f"{C.TERM_YELLOW}{line}{C.TERM_RESET}") # Yellow
for hints
+ else:
+ print(line)
+ self.lines = new_lines
+ if not dry_run:
+ self.path.write_text("\n".join(self.lines) + "\n",
encoding="utf-8")
+ return True
+
+ @staticmethod
+ def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912
+ """Parse a file to extract code blocks based on stub markers."""
+ assert file.is_file(), f"Expected a file, but got: {file}"
+ file = file.resolve()
+ has_marker = False
+ lines: list[str] = file.read_text(encoding="utf-8").splitlines()
+ for line_no, line in enumerate(lines, start=1):
+ if line.strip().startswith(C.STUB_SKIP_FILE):
+ print(
+ f"{C.TERM_YELLOW}[Skipped] skip-file marker found on line
{line_no}: {file}{C.TERM_RESET}"
+ )
+ return None
+ if line.strip().startswith(C.STUB_PREFIX):
+ has_marker = True
+ if not has_marker:
+ return None
+ del has_marker
+
+ codes: list[CodeBlock] = []
+ code: CodeBlock | None = None
+ for lineno, line in enumerate(lines, 1):
+ clean_line = line.strip()
+ if clean_line.startswith(C.STUB_BEGIN): # Process "#
tvm-ffi-stubgen(begin)"
+ if code is not None:
+ raise ValueError(f"Nested stub not permitted at line
{lineno}")
+ code = CodeBlock.from_begin_line(lineno, clean_line)
+ code.lineno_start = lineno
+ code.lines.append(line)
+ elif clean_line.startswith(C.STUB_END): # Process "#
tvm-ffi-stubgen(end)"
+ if code is None:
+ raise ValueError(f"Unmatched `{C.STUB_END}` found at line
{lineno}")
+ code.lineno_end = lineno
+ code.lines.append(line)
+ codes.append(code)
+ code = None
+ elif clean_line.startswith(C.STUB_TY_MAP): # Process "#
tvm-ffi-stubgen(ty_map)"
+ ty_code = CodeBlock.from_begin_line(lineno, clean_line)
+ ty_code.lineno_end = lineno
+ ty_code.lines.append(line)
+ codes.append(ty_code)
+ del ty_code
+ elif clean_line.startswith(C.STUB_PREFIX):
+ raise ValueError(f"Unknown stub type at line {lineno}:
{clean_line}")
+ elif code is None: # Process a plain line outside of any stub
block
+ codes.append(
+ CodeBlock(
+ kind=None, param="", lineno_start=lineno,
lineno_end=lineno, lines=[line]
+ )
+ )
+ else: # Process a line inside a stub block
+ code.lines.append(line)
+ if code is not None:
+ raise ValueError("Unclosed stub block at end of file")
+ return FileInfo(path=file, lines=tuple(lines), code_blocks=codes)
+
+
+def collect_files(paths: list[Path]) -> list[FileInfo]:
+ """Collect all files from the given paths and parse them into FileInfo
objects."""
+
+ def _on_error(e: Exception) -> None:
+ print(
+ f'{C.TERM_RED}[Failed] File
"{file}"\n{traceback.format_exc()}{C.TERM_RESET}',
+ end="",
+ flush=True,
+ )
+
+ def _walk_recursive() -> Generator[Path, None, None]:
+ for p in paths:
+ if p.is_file():
+ yield p
+ continue
+ for root, _dirs, files in path_walk(p, follow_symlinks=False,
on_error=_on_error):
+ for file in files:
+ f = Path(root) / file
+ if f.suffix.lower() not in C.DEFAULT_SOURCE_EXTS:
+ continue
+ yield f
+
+ filenames = list(_walk_recursive())
+ filenames = sorted(filenames, key=lambda f: str(f))
+ files = []
+ for file in filenames:
+ try:
+ content = FileInfo.from_file(file)
+ except Exception as e:
+ _on_error(e)
+ else:
+ if content is not None:
+ files.append(content)
+ return files
+
+
+def path_walk(
+ p: Path,
+ *,
+ top_down: bool = True,
+ on_error: Callable[[Exception], None] | None = None,
+ follow_symlinks: bool = False,
+) -> Iterable[tuple[Path, list[str], list[str]]]:
+ """Compat wrapper for Path.walk (3.12+) with a fallback for < 3.12."""
+ # Python 3.12+ - just delegate to `Path.walk`
+ if hasattr(p, "walk"):
+ yield from p.walk( # type: ignore[attr-defined]
+ top_down=top_down,
+ on_error=on_error,
+ follow_symlinks=follow_symlinks,
+ )
+ return
+ # Python < 3.12 - use `os.walk``
+ import os # noqa: PLC0415
+
+ for root_str, dirnames, filenames in os.walk(
+ p,
+ topdown=top_down,
+ onerror=on_error,
+ followlinks=follow_symlinks,
+ ):
+ root = Path(root_str)
+ # dirnames and filenames are lists of *names*, not full paths,
+ # just like Path.walk()'s documented behavior.
+ yield root, dirnames, filenames
diff --git a/python/tvm_ffi/stub/stubgen.py b/python/tvm_ffi/stub/stubgen.py
deleted file mode 100644
index d796eb4..0000000
--- a/python/tvm_ffi/stub/stubgen.py
+++ /dev/null
@@ -1,528 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``).
-
-Overview
---------
-This module powers the ``tvm-ffi-stubgen`` command line tool which generates
-in-place type stubs for Python modules that integrate with the TVM FFI
-ecosystem. It scans ``.py``/``.pyi`` files for special comment markers and
-fills the enclosed blocks with precise, static type annotations derived from
-runtime metadata exposed by TVM FFI.
-
-Why you might use this
-----------------------
-- You author a Python module that binds to C++/C via TVM FFI and want
- high-quality type hints for functions, objects, and methods.
-- You maintain a downstream extension that registers global functions or
- FFI object types and want your Python API surface to be type-checker
- friendly without manually writing stubs.
-
-How it works (in one sentence)
-------------------------------
-``tvm-ffi-stubgen`` replaces the content between special ``tvm-ffi-stubgen`` markers
-with generated code guarded by ``if TYPE_CHECKING: ...`` so that the runtime
-behavior is unchanged while static analyzers get rich types.
-
-Stub block markers
-------------------
-Insert one of the following begin/end markers in your source, then run
-``tvm-ffi-stubgen``. Indentation on the ``begin`` line is preserved; generated
-content is additionally indented by ``--indent`` spaces (default: 4).
-
-1) Global function stubs
-
- Mark all global functions whose names start with a registry prefix
- (e.g. ``ffi`` or ``my_ffi_extension``):
-
- .. code-block:: python
-
- from typing import TYPE_CHECKING
-
- # tvm-ffi-stubgen(begin): global/ffi
- if TYPE_CHECKING:
- # fmt: off
- # (generated by tvm-ffi-stubgen)
- # fmt: on
- # tvm-ffi-stubgen(end)
-
- ``tvm-ffi-stubgen`` expands this with function signatures discovered via the
- TVM FFI global function registry.
-
-2) Object type stubs
-
- Mark fields and methods for a registered FFI object type using its
- ``type_key`` (the key passed to ``@register_object``):
-
- .. code-block:: python
-
- @register_object("testing.SchemaAllTypes")
- class _SchemaAllTypes:
- # tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes
- # tvm-ffi-stubgen(ty_map): testing.SchemaAllTypes -> _SchemaAllTypes
- if TYPE_CHECKING:
- # fmt: off
- # (generated by tvm-ffi-stubgen)
- # fmt: on
- # tvm-ffi-stubgen(end)
-
- ``tvm-ffi-stubgen`` expands this with annotated attributes and method stub
- signatures. The special C FFI initializer ``__ffi_init__`` is exposed as
- ``__c_ffi_init__`` to avoid interfering with your Python ``__init__``.
-
-3) Skip whole file
-
- If a source file should never be modified by the stub generator, add the
- following directive anywhere in the file:
-
- .. code-block:: python
-
- # tvm-ffi-stubgen(skip-file)
-
- When present, ``tvm-ffi-stubgen`` skips processing this file entirely. This
- is useful for files that are generated by other tooling or vendored.
-
-Optional type mapping lines
----------------------------
-Inside a stub block you may add mapping hints to rename fully-qualified type
-names to simpler aliases in the generated output:
-
-.. code-block:: python
-
- # tvm-ffi-stubgen(ty_map): A.B.C -> C
- # tvm-ffi-stubgen(ty_map): list -> Sequence
- # tvm-ffi-stubgen(ty_map): dict -> Mapping
-
-By default, ``list`` is shown as ``Sequence`` and ``dict`` as ``Mapping``.
-If you use names such as ``Sequence``/``Mapping``, ensure they are available
-to type checkers in your module, for example:
-
-.. code-block:: python
-
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from collections.abc import Mapping, Sequence
-
-Runtime requirements
---------------------
-- Python must be able to import ``tvm_ffi``.
-- The process needs access to the TVM runtime and any extension libraries that
- provide the global functions or object types you want to stub. Use the
- ``--dlls`` option to preload shared libraries when necessary.
-
-What files are modified
------------------------
-Only files with extensions ``.py`` and ``.pyi`` are scanned. Files are updated
-in place. A colored unified diff is printed for each change.
-
-CLI quick start
----------------
-
-.. code-block:: bash
-
- # Generate stubs for a single file
- tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py
-
- # Recursively scan directories for tvm-ffi-stubgen blocks
- tvm-ffi-stubgen python/tvm_ffi examples/packaging/python/my_ffi_extension
-
- # Preload TVM runtime and your extension library before generation
- tvm-ffi-stubgen \
- --dlls build/libtvm_runtime.dylib build/libmy_ext.dylib \
- python/tvm_ffi/_ffi_api.py
-
-Exit status
------------
-Returns 0 on success and 1 if any file fails to process.
-
-"""
-
-from __future__ import annotations
-
-import argparse
-import ctypes
-import dataclasses
-import difflib
-import logging
-import sys
-from io import StringIO
-from pathlib import Path
-from typing import Callable
-
-from tvm_ffi.core import TypeSchema, _lookup_or_register_type_info_from_type_key
-from tvm_ffi.registry import get_global_func_metadata, list_global_func_names
-
-DEFAULT_SOURCE_EXTS = {".py", ".pyi"}
-STUB_BEGIN = "# tvm-ffi-stubgen(begin):"
-STUB_END = "# tvm-ffi-stubgen(end)"
-STUB_TY_MAP = "# tvm-ffi-stubgen(ty_map):"
-STUB_SKIP_FILE = "# tvm-ffi-stubgen(skip-file)"
-
-TERM_RESET = "\033[0m"
-TERM_BOLD = "\033[1m"
-TERM_RED = "\033[31m"
-TERM_GREEN = "\033[32m"
-TERM_YELLOW = "\033[33m"
-
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
-
-
[email protected]
-class Options:
- """Command line options for stub generation."""
-
- dlls: list[str] = dataclasses.field(default_factory=list)
- indent: int = 4
- files: list[str] = dataclasses.field(default_factory=list)
- suppress_print: bool = False
-
-
[email protected]
-class StubConfig:
- """Configuration of a stub block."""
-
- name: str
- indent: int
- lineno: int
- ty_map: dict[str, str] = dataclasses.field(
- default_factory=lambda: dict(
- {
- "list": "Sequence",
- "dict": "Mapping",
- }
- )
- )
-
-
-def _as_func_signature(
- schema: TypeSchema,
- func_name: str,
- ty_map: Callable[[str], str],
-) -> str:
- buf = StringIO()
- buf.write(f"def {func_name}(")
- if schema.origin != "Callable":
- raise ValueError(f"Expected Callable type schema, but got: {schema}")
- if not schema.args:
- buf.write("*args: Any) -> Any:")
- return buf.getvalue()
- arg_ret = schema.args[0]
- arg_args = schema.args[1:]
- for i, arg in enumerate(arg_args):
- buf.write(f"_{i}: ")
- buf.write(arg.repr(ty_map))
- buf.write(", ")
- if arg_args:
- buf.write("/")
- buf.write(") -> ")
- buf.write(arg_ret.repr(ty_map))
- buf.write(":")
- return buf.getvalue()
-
-
-def _filter_files(paths: list[Path]) -> list[Path]:
- results: list[Path] = []
- for p in paths:
- if not p.exists():
- raise FileNotFoundError(f"Path does not exist: {p}")
- if p.is_dir():
- for f in p.rglob("*"):
- if f.is_file() and f.suffix.lower() in DEFAULT_SOURCE_EXTS:
- results.append(f.resolve())
- continue
- f = p.resolve()
- if f.is_file() and f.suffix.lower() in DEFAULT_SOURCE_EXTS:
- results.append(f)
- # Deterministic order
- return sorted(set(results))
-
-
-def _make_type_map(name_map: dict[str, str]) -> Callable[[str], str]:
- def map_type(name: str) -> str:
- if (ret := name_map.get(name)) is not None:
- return ret
- return name.rsplit(".", 1)[-1]
-
- return map_type
-
-
-def _generate_global(
- stub: StubConfig,
- global_func_tab: dict[str, list[str]],
- opt: Options,
-) -> list[str]:
- assert stub.name.startswith("global/")
- prefix = stub.name[len("global/") :].strip()
- ty_map = _make_type_map(stub.ty_map)
- indent = " " * (stub.indent + opt.indent)
- results: list[str] = [
- " " * stub.indent + "if TYPE_CHECKING:",
- f"{indent}# fmt: off",
- ]
- for name in global_func_tab.get(prefix, []):
- schema_str = get_global_func_metadata(f"{prefix}.{name}")["type_schema"]
- schema = TypeSchema.from_json_str(schema_str)
- sig = _as_func_signature(schema, name, ty_map=ty_map)
- func = f"{indent}{sig} ..."
- results.append(func)
- if len(results) > 2:
- results.append(f"{indent}# fmt: on")
- else:
- results = []
- return results
-
-
-def _show_diff(old: list[str], new: list[str]) -> None:
- for line in difflib.unified_diff(old, new, lineterm=""):
- # Skip placeholder headers when fromfile/tofile are unspecified
- if line.startswith("---") or line.startswith("+++"):
- continue
- if line.startswith("-") and not line.startswith("---"):
- print(f"{TERM_RED}{line}{TERM_RESET}") # Red for removals
- elif line.startswith("+") and not line.startswith("+++"):
- print(f"{TERM_GREEN}{line}{TERM_RESET}") # Green for additions
- elif line.startswith("?"):
- print(f"{TERM_YELLOW}{line}{TERM_RESET}") # Yellow for hints
- else:
- print(line)
-
-
-def _generate_object(
- stub: StubConfig,
- opt: Options,
-) -> list[str]:
- assert stub.name.startswith("object/")
- type_key = stub.name[len("object/") :].strip()
- ty_map = _make_type_map(stub.ty_map)
- indent = " " * (stub.indent + opt.indent)
- results: list[str] = [
- " " * stub.indent + "if TYPE_CHECKING:",
- f"{indent}# fmt: off",
- ]
-
- type_info = _lookup_or_register_type_info_from_type_key(type_key)
- for field in type_info.fields:
- schema = TypeSchema.from_json_str(field.metadata["type_schema"])
- schema_str = schema.repr(ty_map=ty_map)
- results.append(f"{indent}{field.name}: {schema_str}")
- for method in type_info.methods:
- name = method.name
- if name == "__ffi_init__":
- name = "__c_ffi_init__"
- schema = TypeSchema.from_json_str(method.metadata["type_schema"])
- schema_str = _as_func_signature(schema, name, ty_map=ty_map)
- if method.is_static:
- results.append(f"{indent}@staticmethod")
- results.append(f"{indent}{schema_str} ...")
- if len(results) > 2:
- results.append(f"{indent}# fmt: on")
- else:
- results = []
- return results
-
-
-def _has_skip_file_marker(lines: list[str]) -> bool:
- for raw in lines:
- if raw.strip().startswith(STUB_SKIP_FILE):
- return True
- return False
-
-
-def _main( # noqa: PLR0912, PLR0915
- file: Path,
- opt: Options,
- global_func_tab: dict[str, list[str]] | None = None,
-) -> None:
- assert file.is_file(), f"Expected a file, but got: {file}"
-
- lines_now = file.read_text(encoding="utf-8").splitlines()
-
- # directive(skip-file): skip processing this file entirely if present.
- if _has_skip_file_marker(lines_now):
- if not opt.suppress_print:
- print(f"{TERM_YELLOW}[Skipped] {file}{TERM_RESET}")
- return
-
- if global_func_tab is None:
- global_func_tab = _compute_global_func_tab()
-
- lines_new: list[str] = []
- stub: StubConfig | None = None
- skipped: bool = True
- for lineno, line in enumerate(lines_now, 1):
- clean_line = line.strip()
- if clean_line.startswith(STUB_BEGIN):
- if stub is not None:
- raise ValueError(f"Nested stub not permitted, but found at {file}:{lineno}")
- stub = StubConfig(
- name=clean_line[len(STUB_BEGIN) :].strip(),
- indent=len(line) - len(clean_line),
- lineno=lineno,
- )
- skipped = False
- lines_new.append(line)
- elif clean_line.startswith(STUB_END):
- if stub is None:
- raise ValueError(f"Unmatched stub end found at {file}:{lineno}")
- if stub.name.startswith("global/"):
- lines_new.extend(_generate_global(stub, global_func_tab, opt))
- elif stub.name.startswith("object/"):
- lines_new.extend(_generate_object(stub, opt))
- else:
- raise ValueError(f"Unknown stub type `{stub.name}` at {file}:{stub.lineno}")
- stub = None
- lines_new.append(line)
- elif clean_line.startswith(STUB_TY_MAP):
- if stub is None:
- raise ValueError(f"Stub ty_map outside stub block at {file}:{lineno}")
- ty_map = clean_line[len(STUB_TY_MAP) :].strip()
- try:
- lhs, rhs = ty_map.split("->")
- except ValueError as e:
- raise ValueError(
- f"Invalid ty_map format at {file}:{lineno}. Example: `A.B -> C`"
- ) from e
- lhs = lhs.strip()
- rhs = rhs.strip()
- stub.ty_map[lhs] = rhs
- lines_new.append(line)
- elif stub is None:
- lines_new.append(line)
- if stub is not None:
- raise ValueError(f"Unclosed stub block at end of file: {file}")
- if not skipped:
- if lines_now != lines_new:
- if not opt.suppress_print:
- print(f"{TERM_GREEN}[Updated] {file}{TERM_RESET}")
- _show_diff(lines_now, lines_new)
- file.write_text("\n".join(lines_new) + "\n", encoding="utf-8")
- elif not opt.suppress_print:
- print(f"{TERM_BOLD}[Unchanged] {file}{TERM_RESET}")
-
-
-def _compute_global_func_tab() -> dict[str, list[str]]:
- # Build global function table only if we are going to process blocks.
- global_func_tab: dict[str, list[str]] = {}
- for name in list_global_func_names():
- prefix, suffix = name.rsplit(".", 1)
- global_func_tab.setdefault(prefix, []).append(suffix)
- # Ensure stable ordering for deterministic output.
- for k in list(global_func_tab.keys()):
- global_func_tab[k].sort()
- return global_func_tab
-
-
-def __main__() -> int:
- """Command line entry point for ``tvm-ffi-stubgen``.
-
- This generates in-place type stubs inside special ``tvm-ffi-stubgen`` blocks
- in the given files or directories. See the module docstring for an
- overview and examples of the block syntax.
- """
-
- class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
- pass
-
- parser = argparse.ArgumentParser(
- prog="tvm-ffi-stubgen",
- description=(
- "Generate in-place type stubs for TVM FFI.\n\n"
- "It scans .py/.pyi files for tvm-ffi-stubgen blocks and fills them with\n"
- "TYPE_CHECKING-only annotations derived from TVM runtime metadata."
- ),
- formatter_class=HelpFormatter,
- epilog=(
- "Examples:\n"
- " # Single file\n"
- " tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py\n\n"
- " # Recursively scan directories\n"
- " tvm-ffi-stubgen python/tvm_ffi examples/packaging/python/my_ffi_extension\n\n"
- " # Preload TVM runtime / extension libraries\n"
- " tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n\n"
- "Stub block syntax (placed in your source):\n"
- " # tvm-ffi-stubgen(begin): global/<registry-prefix>\n"
- " ... generated function stubs ...\n"
- " # tvm-ffi-stubgen(end)\n\n"
- " # tvm-ffi-stubgen(begin): object/<type_key>\n"
- " # tvm-ffi-stubgen(ty_map): list -> Sequence\n"
- " # tvm-ffi-stubgen(ty_map): dict -> Mapping\n"
- " ... generated fields and methods ...\n"
- " # tvm-ffi-stubgen(end)\n\n"
- " # Skip a file entirely\n"
- " # tvm-ffi-stubgen(skip-file)\n\n"
- "Tips:\n"
- " - Only .py/.pyi files are updated; directories are scanned recursively.\n"
- " - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n"
- " from collections.abc import Mapping, Sequence\n"
- " - Use --dlls to preload shared libraries when function/type metadata\n"
- " is provided by native extensions.\n"
- ),
- )
- parser.add_argument(
- "--dlls",
- nargs="*",
- metavar="LIB",
- help=(
- "Shared libraries to preload before generation (e.g. TVM runtime or "
- "your extension). This ensures global function and object metadata "
- "is available. Accepts multiple paths; platform-specific suffixes "
- "like .so/.dylib/.dll are supported."
- ),
- default=[],
- )
- parser.add_argument(
- "--indent",
- type=int,
- default=4,
- help=(
- "Extra spaces added inside each generated block, relative to the "
- "indentation of the corresponding '# tvm-ffi-stubgen(begin):' line."
- ),
- )
- parser.add_argument(
- "files",
- nargs="*",
- metavar="PATH",
- help=(
- "Files or directories to process. Directories are scanned recursively; "
- "only .py and .pyi files are modified. Use tvm-ffi-stubgen markers to "
- "select where stubs are generated."
- ),
- )
- opt = Options(**vars(parser.parse_args()))
- if not opt.files:
- parser.print_help()
- return 1
-
- dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
- global_func_tab = _compute_global_func_tab()
- rc = 0
- try:
- for file in _filter_files([Path(f) for f in opt.files]):
- try:
- _main(file, opt, global_func_tab=global_func_tab)
- except Exception:
- logger.exception(f"{TERM_RED}[Failed] {file}{TERM_RESET}")
- rc = 1
- finally:
- del dlls
- return rc
-
-
-if __name__ == "__main__":
- sys.exit(__main__())
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
new file mode 100644
index 0000000..7f461c4
--- /dev/null
+++ b/python/tvm_ffi/stub/utils.py
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Common utilities for the `tvm-ffi-stubgen` tool."""
+
+from __future__ import annotations
+
+import dataclasses
+
+
[email protected]
+class Options:
+ """Command line options for stub generation."""
+
+ dlls: list[str] = dataclasses.field(default_factory=list)
+ indent: int = 4
+ files: list[str] = dataclasses.field(default_factory=list)
+ verbose: bool = False
+ dry_run: bool = False
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index ce9205f..2b157b1 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -15,34 +15,40 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities."""
-# ruff: noqa: D102,D105
+# ruff: noqa: D102,D105
+# tvm-ffi-stubgen(begin): import
+# fmt: off
+# isort: off
from __future__ import annotations
+from typing import Any, TYPE_CHECKING
+if TYPE_CHECKING:
+ from collections.abc import Mapping, Sequence
+ from tvm_ffi import Device, Object, dtype
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
-from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import ClassVar
from . import _ffi_api
from .core import Object
from .dataclasses import c_class, field
from .registry import get_global_func, register_object
-if TYPE_CHECKING:
- from tvm_ffi import Device, dtype
-
@register_object("testing.TestObjectBase")
class TestObjectBase(Object):
"""Test object base class."""
# tvm-ffi-stubgen(begin): object/testing.TestObjectBase
+ # fmt: off
+ v_i64: int
+ v_f64: float
+ v_str: str
if TYPE_CHECKING:
- # fmt: off
- v_i64: int
- v_f64: float
- v_str: str
- def add_i64(_0: TestObjectBase, _1: int, /) -> int: ...
- # fmt: on
+ def add_i64(self, _1: int, /) -> int: ...
+ # fmt: on
# tvm-ffi-stubgen(end)
@@ -51,14 +57,14 @@ class TestIntPair(Object):
"""Test Int Pair."""
# tvm-ffi-stubgen(begin): object/testing.TestIntPair
+ # fmt: off
+ a: int
+ b: int
if TYPE_CHECKING:
- # fmt: off
- a: int
- b: int
@staticmethod
def __c_ffi_init__(_0: int, _1: int, /) -> Object: ...
- def sum(_0: TestIntPair, /) -> int: ...
- # fmt: on
+ def sum(self, /) -> int: ...
+ # fmt: on
# tvm-ffi-stubgen(end)
@@ -67,42 +73,41 @@ class TestObjectDerived(TestObjectBase):
"""Test object derived class."""
# tvm-ffi-stubgen(begin): object/testing.TestObjectDerived
- if TYPE_CHECKING:
- # fmt: off
- v_map: Mapping[Any, Any]
- v_array: Sequence[Any]
- # fmt: on
+ # fmt: off
+ v_map: Mapping[Any, Any]
+ v_array: Sequence[Any]
+ # fmt: on
# tvm-ffi-stubgen(end)
@register_object("testing.SchemaAllTypes")
class _SchemaAllTypes:
+ # tvm-ffi-stubgen(ty-map): testing.SchemaAllTypes -> testing._SchemaAllTypes
# tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes
- # tvm-ffi-stubgen(ty_map): testing.SchemaAllTypes -> _SchemaAllTypes
+ # fmt: off
+ v_bool: bool
+ v_int: int
+ v_float: float
+ v_device: Device
+ v_dtype: dtype
+ v_string: str
+ v_bytes: bytes
+ v_opt_int: int | None
+ v_opt_str: str | None
+ v_arr_int: Sequence[int]
+ v_arr_str: Sequence[str]
+ v_map_str_int: Mapping[str, int]
+ v_map_str_arr_int: Mapping[str, Sequence[int]]
+ v_variant: str | Sequence[int] | Mapping[str, int]
+ v_opt_arr_variant: Sequence[int | str] | None
if TYPE_CHECKING:
- # fmt: off
- v_bool: bool
- v_int: int
- v_float: float
- v_device: Device
- v_dtype: dtype
- v_string: str
- v_bytes: bytes
- v_opt_int: int | None
- v_opt_str: str | None
- v_arr_int: Sequence[int]
- v_arr_str: Sequence[str]
- v_map_str_int: Mapping[str, int]
- v_map_str_arr_int: Mapping[str, Sequence[int]]
- v_variant: str | Sequence[int] | Mapping[str, int]
- v_opt_arr_variant: Sequence[int | str] | None
- def add_int(_0: _SchemaAllTypes, _1: int, /) -> int: ...
- def append_int(_0: _SchemaAllTypes, _1: Sequence[int], _2: int, /) -> Sequence[int]: ...
- def maybe_concat(_0: _SchemaAllTypes, _1: str | None, _2: str | None, /) -> str | None: ...
- def merge_map(_0: _SchemaAllTypes, _1: Mapping[str, Sequence[int]], _2: Mapping[str, Sequence[int]], /) -> Mapping[str, Sequence[int]]: ...
+ def add_int(self, _1: int, /) -> int: ...
+ def append_int(self, _1: Sequence[int], _2: int, /) -> Sequence[int]: ...
+ def maybe_concat(self, _1: str | None, _2: str | None, /) -> str | None: ...
+ def merge_map(self, _1: Mapping[str, Sequence[int]], _2: Mapping[str, Sequence[int]], /) -> Mapping[str, Sequence[int]]: ...
@staticmethod
def make_with(_0: int, _1: float, _2: str, /) -> _SchemaAllTypes: ...
- # fmt: on
+ # fmt: on
# tvm-ffi-stubgen(end)
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
deleted file mode 100644
index 574d2b3..0000000
--- a/tests/python/test_stubgen.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import re
-from pathlib import Path
-
-import pytest
-from tvm_ffi.stub import stubgen
-
-
-def test_stubgen_skip_file(tmp_path: Path) -> None:
- p: Path = tmp_path / "dummy.py"
- src = (
- "# tvm-ffi-stubgen(skip-file)\n"
- "from typing import TYPE_CHECKING\n\n"
- "# tvm-ffi-stubgen(begin): global/ffi\n"
- "if TYPE_CHECKING:\n"
- " pass\n"
- "# tvm-ffi-stubgen(end)\n"
- )
- p.write_text(src, encoding="utf-8")
- # Run the generator; it should skip without trying to query the registry
- stubgen._main(p, stubgen.Options(indent=4, suppress_print=True))
- # File must be unchanged
- assert p.read_text(encoding="utf-8") == src
-
-
-def test_stubgen_global_block_generates_and_indents(tmp_path: Path) -> None:
- p: Path = tmp_path / "gen_global.py"
- # Indent begin by 2 spaces; inner indent is begin-indent + opt.indent (3)
- src = (
- "from typing import TYPE_CHECKING\n\n"
- " # tvm-ffi-stubgen(begin): global/ffi\n"
- " # tvm-ffi-stubgen(end)\n"
- )
- p.write_text(src, encoding="utf-8")
-
- stubgen._main(p, stubgen.Options(indent=3, suppress_print=True))
-
- out = p.read_text(encoding="utf-8").splitlines()
-
- # Expect TYPE_CHECKING guard with the same begin indentation
- assert any(line == " if TYPE_CHECKING:" for line in out)
- # Expect formatting guards
- assert any(line == " # fmt: off" for line in out)
- assert any(line == " # fmt: on" for line in out)
- # Expect at least one known ffi function signature, e.g. String(...)-> str
- string_lines = [ln for ln in out if re.search(r"\bdef\s+String\(.*\)\s*->\s*str:\s*\.\.\.", ln)]
- assert string_lines, "Expected stub for ffi.String"
- # Check inner indent equals begin (2) + opt.indent (3) = 5 spaces
- assert all(ln.startswith(" " * 5) for ln in string_lines)
-
- # Idempotency: second run should keep file unchanged
- before = "\n".join(out) + "\n"
- stubgen._main(p, stubgen.Options(indent=3, suppress_print=True))
- assert p.read_text(encoding="utf-8") == before
-
-
-def test_stubgen_global_block_no_matches_is_noop(tmp_path: Path) -> None:
- p: Path = tmp_path / "gen_global_empty.py"
- src = "# tvm-ffi-stubgen(begin): global/this_prefix_does_not_exist\n# tvm-ffi-stubgen(end)\n"
- p.write_text(src, encoding="utf-8")
- stubgen._main(p, stubgen.Options(indent=4, suppress_print=True))
- assert p.read_text(encoding="utf-8") == src
-
-
-def test_stubgen_object_block_generates_fields_and_methods(tmp_path: Path) -> None:
- # Ensure object type registrations are loaded
-
- p: Path = tmp_path / "gen_object_pair.py"
- src = (
- "class _C:\n"
- " # tvm-ffi-stubgen(begin): object/testing.TestIntPair\n"
- " # tvm-ffi-stubgen(end)\n"
- )
- p.write_text(src, encoding="utf-8")
-
- stubgen._main(p, stubgen.Options(indent=4, suppress_print=True))
- out = p.read_text(encoding="utf-8").splitlines()
-
- # Fields a and b should be generated
- assert any(" a: int" == ln for ln in out)
- assert any(" b: int" == ln for ln in out)
- # __ffi_init__ should be exposed as __c_ffi_init__ and marked staticmethod
- init_idx = next(i for i, ln in enumerate(out) if "def __c_ffi_init__(" in ln)
- assert out[init_idx - 1].strip() == "@staticmethod"
-
-
-def test_stubgen_object_block_with_ty_map_and_collections(tmp_path: Path) -> None:
- # Ensure type info for SchemaAllTypes is available
- p: Path = tmp_path / "gen_object_schema.py"
- src = (
- "# tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes\n"
- "# tvm-ffi-stubgen(ty_map): testing.SchemaAllTypes -> _SchemaAllTypes\n"
- "# tvm-ffi-stubgen(end)\n"
- )
- p.write_text(src, encoding="utf-8")
-
- stubgen._main(p, stubgen.Options(indent=4, suppress_print=True))
- text = p.read_text(encoding="utf-8")
-
- # Mapped container aliases should appear
- assert "Sequence[int]" in text
- assert "Mapping[str, Sequence[int]]" in text
- # Method types reflect mapping of the object type
- assert re.search(r"def\s+add_int\(_0: _SchemaAllTypes, _1: int, /\)\s*->\s*int:\s*\.\.\.", text)
- # Static factory returns the mapped type
- assert re.search(
- r"@staticmethod\s*\n\s*def\s+make_with\(.*\)\s*->\s*_SchemaAllTypes:\s*\.\.\.", text
- )
-
-
-def test_stubgen_errors_for_invalid_directives(tmp_path: Path) -> None:
- # ty_map outside a block
- p1 = tmp_path / "invalid_ty_map_outside.py"
- p1.write_text("# tvm-ffi-stubgen(ty_map): A.B -> C\n", encoding="utf-8")
- with pytest.raises(ValueError, match="Stub ty_map outside stub block"):
- stubgen._main(p1, stubgen.Options(suppress_print=True))
-
- # invalid ty_map format inside a block
- p2 = tmp_path / "invalid_ty_map_format.py"
- p2.write_text(
- (
- "# tvm-ffi-stubgen(begin): object/testing.TestObjectBase\n"
- "# tvm-ffi-stubgen(ty_map): not_a_map_line\n"
- "# tvm-ffi-stubgen(end)\n"
- ),
- encoding="utf-8",
- )
- with pytest.raises(ValueError, match=r"Invalid ty_map format"):
- stubgen._main(p2, stubgen.Options(suppress_print=True))
-
-
-def test_stubgen_errors_for_block_structure(tmp_path: Path) -> None:
- # Nested stub blocks are not allowed
- p_nested = tmp_path / "nested.py"
- p_nested.write_text(
- (
- "# tvm-ffi-stubgen(begin): global/ffi\n"
- " # tvm-ffi-stubgen(begin): global/ffi\n"
- "# tvm-ffi-stubgen(end)\n"
- ),
- encoding="utf-8",
- )
- with pytest.raises(ValueError, match=r"Nested stub not permitted"):
- stubgen._main(p_nested, stubgen.Options(suppress_print=True))
-
- # Unmatched end
- p_unmatched_end = tmp_path / "unmatched_end.py"
- p_unmatched_end.write_text("# tvm-ffi-stubgen(end)\n", encoding="utf-8")
- with pytest.raises(ValueError, match=r"Unmatched stub end"):
- stubgen._main(p_unmatched_end, stubgen.Options(suppress_print=True))
-
- # Unknown stub type
- p_unknown = tmp_path / "unknown.py"
- p_unknown.write_text(
- ("# tvm-ffi-stubgen(begin): unknown/foo\n# tvm-ffi-stubgen(end)\n"),
- encoding="utf-8",
- )
- with pytest.raises(ValueError, match=r"Unknown stub type"):
- stubgen._main(p_unknown, stubgen.Options(suppress_print=True))
-
- # Unclosed block
- p_unclosed = tmp_path / "unclosed.py"
- p_unclosed.write_text("# tvm-ffi-stubgen(begin): global/ffi\n", encoding="utf-8")
- with pytest.raises(ValueError, match=r"Unclosed stub block"):
- stubgen._main(p_unclosed, stubgen.Options(suppress_print=True))