This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 92e150b feat(stubgen): Generate `__all__` for proper exporting (#268)
92e150b is described below
commit 92e150b9cb8222eeadf5469eae2e06400e4af850
Author: Junru Shao <[email protected]>
AuthorDate: Sun Nov 16 15:44:58 2025 -0800
feat(stubgen): Generate `__all__` for proper exporting (#268)
This is the last component before introducing whole-package generation.
With the directive below
```python
__all__ = [
# tvm-ffi-stubgen(begin): __all__
# tvm-ffi-stubgen(end)
]
```
, this PR enables generating an `__all__` list that exports all auto-generated
global functions and classes, e.g.:
```diff
__all__ = [
# tvm-ffi-stubgen(begin): __all__
+ "Array",
+ "ArrayGetItem",
+ "ArraySize",
+ "Bytes",
+ "FromJSONGraph",
+ "FromJSONGraphString",
+ "FunctionListGlobalNamesFunctor",
+ "FunctionRemoveGlobal",
+ "GetFirstStructuralMismatch",
+ "GetGlobalFuncMetadata",
+ "GetRegisteredTypeKeys",
+ "MakeObjectFromPackedArgs",
+ "Map",
+ "MapCount",
+ "MapForwardIterFunctor",
+ "MapGetItem",
+ "MapSize",
+ "ModuleClearImports",
+ "ModuleGetFunction",
+ "ModuleGetFunctionDoc",
+ "ModuleGetFunctionMetadata",
+ "ModuleGetKind",
+ "ModuleGetPropertyMask",
+ "ModuleGetWriteFormats",
+ "ModuleImplementsFunction",
+ "ModuleImportModule",
+ "ModuleInspectSource",
+ "ModuleLoadFromFile",
+ "ModuleWriteToFile",
+ "Shape",
+ "String",
+ "StructuralHash",
+ "SystemLib",
+ "ToJSONGraph",
+ "ToJSONGraphString",
# tvm-ffi-stubgen(end)
]
```
This PR also includes the refactoring necessary to enable unit tests for
stubgen's core functionality without having to rely on the CLI.
---
python/tvm_ffi/_ffi_api.py | 41 +++++
python/tvm_ffi/core.pyi | 1 +
python/tvm_ffi/stub/analysis.py | 33 +---
python/tvm_ffi/stub/cli.py | 39 +++-
python/tvm_ffi/stub/codegen.py | 142 +++++----------
python/tvm_ffi/stub/consts.py | 4 +
python/tvm_ffi/stub/file_utils.py | 9 +-
python/tvm_ffi/stub/utils.py | 119 +++++++++++++
tests/python/test_stubgen.py | 361 ++++++++++++++++++++++++++++++++++++++
9 files changed, 620 insertions(+), 129 deletions(-)
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 1bdef7a..ab6704a 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -73,3 +73,44 @@ if TYPE_CHECKING:
# tvm-ffi-stubgen(end)
registry.init_ffi_api("ffi", __name__)
+
+
+__all__ = [
+ # tvm-ffi-stubgen(begin): __all__
+ "Array",
+ "ArrayGetItem",
+ "ArraySize",
+ "Bytes",
+ "FromJSONGraph",
+ "FromJSONGraphString",
+ "FunctionListGlobalNamesFunctor",
+ "FunctionRemoveGlobal",
+ "GetFirstStructuralMismatch",
+ "GetGlobalFuncMetadata",
+ "GetRegisteredTypeKeys",
+ "MakeObjectFromPackedArgs",
+ "Map",
+ "MapCount",
+ "MapForwardIterFunctor",
+ "MapGetItem",
+ "MapSize",
+ "ModuleClearImports",
+ "ModuleGetFunction",
+ "ModuleGetFunctionDoc",
+ "ModuleGetFunctionMetadata",
+ "ModuleGetKind",
+ "ModuleGetPropertyMask",
+ "ModuleGetWriteFormats",
+ "ModuleImplementsFunction",
+ "ModuleImportModule",
+ "ModuleInspectSource",
+ "ModuleLoadFromFile",
+ "ModuleWriteToFile",
+ "Shape",
+ "String",
+ "StructuralHash",
+ "SystemLib",
+ "ToJSONGraph",
+ "ToJSONGraphString",
+ # tvm-ffi-stubgen(end)
+]
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 261803c..06a78e6 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -230,6 +230,7 @@ class TypeSchema:
origin: str
args: tuple[TypeSchema, ...] = ()
+ def __init__(self, origin: str, args: tuple[TypeSchema, ...] = ()) ->
None: ...
@staticmethod
def from_json_obj(obj: dict[str, Any]) -> TypeSchema: ...
@staticmethod
diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py
index 2207397..03dbe36 100644
--- a/python/tvm_ffi/stub/analysis.py
+++ b/python/tvm_ffi/stub/analysis.py
@@ -21,42 +21,21 @@ from __future__ import annotations
from tvm_ffi.registry import list_global_func_names
from . import consts as C
-from .file_utils import FileInfo
-from .utils import Options
+from .utils import FuncInfo
-def collect_global_funcs() -> dict[str, list[str]]:
+def collect_global_funcs() -> dict[str, list[FuncInfo]]:
"""Collect global functions from TVM FFI's global registry."""
# Build global function table only if we are going to process blocks.
- global_funcs: dict[str, list[str]] = {}
+ global_funcs: dict[str, list[FuncInfo]] = {}
for name in list_global_func_names():
try:
- prefix, suffix = name.rsplit(".", 1)
+ prefix, _ = name.rsplit(".", 1)
except ValueError:
print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function:
{name}{C.TERM_RESET}")
else:
- global_funcs.setdefault(prefix, []).append(suffix)
+ global_funcs.setdefault(prefix,
[]).append(FuncInfo.from_global_name(name))
# Ensure stable ordering for deterministic output.
for k in list(global_funcs.keys()):
- global_funcs[k].sort()
+ global_funcs[k].sort(key=lambda x: x.schema.name)
return global_funcs
-
-
-def collect_ty_maps(files: list[FileInfo], opt: Options) -> dict[str, str]:
- """Collect type maps from the given files."""
- ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
- for file in files:
- for code in file.code_blocks:
- if code.kind == "ty-map":
- try:
- lhs, rhs = code.param.split("->")
- except ValueError as e:
- raise ValueError(
- f"Invalid ty_map format at line {code.lineno_start}.
Example: `A.B -> C.D`"
- ) from e
- ty_map[lhs.strip()] = rhs.strip()
- if opt.verbose:
- for lhs in sorted(ty_map):
- rhs = ty_map[lhs]
- print(f"{C.TERM_CYAN}[TY-MAP] {lhs} -> {rhs}{C.TERM_RESET}")
- return ty_map
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
index 5728e12..85697f5 100644
--- a/python/tvm_ffi/stub/cli.py
+++ b/python/tvm_ffi/stub/cli.py
@@ -28,7 +28,7 @@ from typing import Callable
from . import codegen as G
from . import consts as C
-from .analysis import collect_global_funcs, collect_ty_maps
+from .analysis import collect_global_funcs
from .file_utils import FileInfo, collect_files
from .utils import Options
@@ -56,17 +56,37 @@ def __main__() -> int:
"""
opt = _parse_args()
dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
- global_funcs = collect_global_funcs()
files: list[FileInfo] = collect_files([Path(f) for f in opt.files])
# Stage 1: Process `tvm-ffi-stubgen(ty-map)`
- ty_map: dict[str, str] = collect_ty_maps(files, opt)
+ ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
+
+ def _stage_1(file: FileInfo) -> None:
+ for code in file.code_blocks:
+ if code.kind == "ty-map":
+ try:
+ lhs, rhs = code.param.split("->")
+ except ValueError as e:
+ raise ValueError(
+ f"Invalid ty_map format at line {code.lineno_start}.
Example: `A.B -> C.D`"
+ ) from e
+ ty_map[lhs.strip()] = rhs.strip()
+
+ for file in files:
+ try:
+ _stage_1(file)
+ except Exception:
+ print(
+ f'{C.TERM_RED}[Failed] File "{file.path}":
{traceback.format_exc()}{C.TERM_RESET}'
+ )
# Stage 2: Process
# - `tvm-ffi-stubgen(begin): global/...`
# - `tvm-ffi-stubgen(begin): object/...`
+ global_funcs = collect_global_funcs()
def _stage_2(file: FileInfo) -> None:
+ all_defined = set()
if opt.verbose:
print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
ty_used: set[str] = set()
@@ -75,17 +95,26 @@ def __main__() -> int:
# Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
for code in file.code_blocks:
if code.kind == "global":
- G.generate_global_funcs(code, global_funcs, fn_ty_map_fn, opt)
+ funcs = global_funcs.get(code.param, [])
+ for func in funcs:
+ all_defined.add(func.schema.name)
+ G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt)
# Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
for code in file.code_blocks:
if code.kind == "object":
+ type_key = code.param
+ ty_on_file.add(ty_map.get(type_key, type_key))
G.generate_object(code, fn_ty_map_fn, opt)
- ty_on_file.add(ty_map.get(code.param, code.param))
# Stage 2.3. Add imports for used types.
for code in file.code_blocks:
if code.kind == "import":
G.generate_imports(code, ty_used - ty_on_file, opt)
break # Only one import block per file is supported for now.
+ # Stage 2.4. Add `__all__` for defined classes and functions.
+ for code in file.code_blocks:
+ if code.kind == "__all__":
+ G.generate_all(code, all_defined | ty_on_file, opt)
+ break # Only one __all__ block per file is supported for now.
file.update(show_diff=opt.verbose, dry_run=opt.dry_run)
for file in files:
diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py
index be33f16..c15624a 100644
--- a/python/tvm_ffi/stub/codegen.py
+++ b/python/tvm_ffi/stub/codegen.py
@@ -18,77 +18,36 @@
from __future__ import annotations
-from io import StringIO
from typing import Callable
-from tvm_ffi.core import TypeSchema,
_lookup_or_register_type_info_from_type_key
-from tvm_ffi.registry import get_global_func_metadata
-
from . import consts as C
from .file_utils import CodeBlock
-from .utils import Options
-
-
-def generate_func_signature(
- schema: TypeSchema,
- func_name: str,
- ty_map: Callable[[str], str],
- is_member: bool,
-) -> str:
- """Generate a function signature string from a TypeSchema."""
- buf = StringIO()
- buf.write(f"def {func_name}(")
- if schema.origin != "Callable":
- raise ValueError(f"Expected Callable type schema, but got: {schema}")
- if not schema.args:
- ty_map("Any")
- buf.write("*args: Any) -> Any: ...")
- return buf.getvalue()
- arg_ret = schema.args[0]
- arg_args = schema.args[1:]
- for i, arg in enumerate(arg_args):
- if is_member and i == 0:
- buf.write("self, ")
- else:
- buf.write(f"_{i}: ")
- buf.write(arg.repr(ty_map))
- buf.write(", ")
- if arg_args:
- buf.write("/")
- buf.write(") -> ")
- buf.write(arg_ret.repr(ty_map))
- buf.write(": ...")
- return buf.getvalue()
+from .utils import FuncInfo, ObjectInfo, Options
def generate_global_funcs(
- code: CodeBlock,
- global_funcs: dict[str, list[str]],
- fn_ty_map: Callable[[str], str],
- opt: Options,
+ code: CodeBlock, global_funcs: list[FuncInfo], fn_ty_map: Callable[[str],
str], opt: Options
) -> None:
"""Generate function signatures for global functions."""
assert len(code.lines) >= 2
- indent = " " * code.indent
- indent_long = " " * (code.indent + opt.indent)
- prefix = code.param
+ if not global_funcs:
+ return
results: list[str] = [
- generate_func_signature(
-
TypeSchema.from_json_str(get_global_func_metadata(f"{prefix}.{name}")["type_schema"]),
- name,
- ty_map=fn_ty_map,
- is_member=False,
- )
- for name in global_funcs.get(prefix, [])
+ "# fmt: off",
+ "if TYPE_CHECKING:",
+ *[
+ func.gen(
+ fn_ty_map,
+ indent=opt.indent,
+ )
+ for func in global_funcs
+ ],
+ "# fmt: on",
]
- if not results:
- return
+ indent = " " * code.indent
code.lines = [
code.lines[0],
- f"{indent}# fmt: off",
- f"{indent}if TYPE_CHECKING:",
- *[indent_long + sig for sig in results],
- f"{indent}# fmt: on",
+ *[indent + line for line in results],
code.lines[-1],
]
@@ -96,50 +55,30 @@ def generate_global_funcs(
def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt:
Options) -> None:
"""Generate a class definition for an object type."""
assert len(code.lines) >= 2
- type_key = code.param
- type_info = _lookup_or_register_type_info_from_type_key(type_key)
+ info = ObjectInfo.from_type_key(code.param)
+ if info.methods:
+ results = [
+ "# fmt: off",
+ *info.gen_fields(fn_ty_map, indent=0),
+ "if TYPE_CHECKING:",
+ *info.gen_methods(fn_ty_map, indent=opt.indent),
+ "# fmt: on",
+ ]
+ else:
+ results = [
+ "# fmt: off",
+ *info.gen_fields(fn_ty_map, indent=0),
+ "# fmt: on",
+ ]
indent = " " * code.indent
- indent_long = " " * (code.indent + opt.indent)
-
- fields: list[str] = []
- for field in type_info.fields:
- fields.append(
- f"{indent}{field.name}: "
- +
TypeSchema.from_json_str(field.metadata["type_schema"]).repr(fn_ty_map)
- )
-
- methods: list[str] = []
- if type_info.methods:
- methods = [f"{indent}if TYPE_CHECKING:"]
- for method in type_info.methods:
- if method.is_static:
- methods.append(f"{indent_long}@staticmethod")
- methods.append(
- indent_long
- + generate_func_signature(
- TypeSchema.from_json_str(method.metadata["type_schema"]),
- {
- "__ffi_init__": "__c_ffi_init__",
- }.get(method.name, method.name),
- fn_ty_map,
- is_member=not method.is_static,
- )
- )
code.lines = [
code.lines[0],
- f"{indent}# fmt: off",
- *fields,
- *methods,
- f"{indent}# fmt: on",
+ *[indent + line for line in results],
code.lines[-1],
]
-def generate_imports(
- code: CodeBlock,
- ty_used: set[str],
- opt: Options,
-) -> None:
+def generate_imports(code: CodeBlock, ty_used: set[str], opt: Options) -> None:
"""Generate import statements for the types used in the stub."""
ty_collected: dict[str, list[str]] = {}
for ty in ty_used:
@@ -181,3 +120,18 @@ def generate_imports(
"# fmt: on",
code.lines[-1],
]
+
+
+def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None:
+ """Generate an `__all__` variable for the given names."""
+ assert len(code.lines) >= 2
+ if not names:
+ return
+
+ indent = " " * code.indent
+ names = {f.rsplit(".", 1)[-1] for f in names}
+ code.lines = [
+ code.lines[0],
+ *[f'{indent}"{name}",' for name in sorted(names)],
+ code.lines[-1],
+ ]
diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py
index 6443c95..6922254 100644
--- a/python/tvm_ffi/stub/consts.py
+++ b/python/tvm_ffi/stub/consts.py
@@ -54,3 +54,7 @@ MOD_MAP = {
"testing": "tvm_ffi.testing",
"ffi": "tvm_ffi",
}
+
+FN_NAME_MAP = {
+ "__ffi_init__": "__c_ffi_init__",
+}
diff --git a/python/tvm_ffi/stub/file_utils.py
b/python/tvm_ffi/stub/file_utils.py
index 6f95025..f100c55 100644
--- a/python/tvm_ffi/stub/file_utils.py
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -31,7 +31,7 @@ from . import consts as C
class CodeBlock:
"""A block of code to be generated in a stub file."""
- kind: Literal["global", "object", "ty-map", "import", None]
+ kind: Literal["global", "object", "ty-map", "import", "__all__", None]
param: str
lineno_start: int
lineno_end: int | None
@@ -39,7 +39,7 @@ class CodeBlock:
def __post_init__(self) -> None:
"""Validate the code block after initialization."""
- assert self.kind in {"global", "object", "ty-map", "import", None}
+ assert self.kind in {"global", "object", "ty-map", "import",
"__all__", None}
@property
def indent(self) -> int:
@@ -74,6 +74,9 @@ class CodeBlock:
elif stub.startswith("import"):
kind = "import"
param = ""
+ elif stub == "__all__":
+ kind = "__all__"
+ param = ""
else:
raise ValueError(f"Unknown stub type `{stub}` at line {lineo}")
return CodeBlock(
@@ -178,7 +181,7 @@ def collect_files(paths: list[Path]) -> list[FileInfo]:
def _on_error(e: Exception) -> None:
print(
- f'{C.TERM_RED}[Failed] File
"{file}"\n{traceback.format_exc()}{C.TERM_RESET}',
+ f"{C.TERM_RED}[Error]\n{traceback.format_exc()}{C.TERM_RESET}",
end="",
flush=True,
)
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
index 7f461c4..e8beb41 100644
--- a/python/tvm_ffi/stub/utils.py
+++ b/python/tvm_ffi/stub/utils.py
@@ -19,6 +19,12 @@
from __future__ import annotations
import dataclasses
+from io import StringIO
+from typing import Callable
+
+from tvm_ffi.core import TypeSchema
+
+from . import consts as C
@dataclasses.dataclass
@@ -30,3 +36,116 @@ class Options:
files: list[str] = dataclasses.field(default_factory=list)
verbose: bool = False
dry_run: bool = False
+
+
+@dataclasses.dataclass(init=False)
+class NamedTypeSchema(TypeSchema):
+ """A type schema with an associated name."""
+
+ name: str
+
+ def __init__(self, name: str, schema: TypeSchema) -> None:
+ """Initialize a `NamedTypeSchema` with the given name and type
schema."""
+ super().__init__(origin=schema.origin, args=schema.args)
+ self.name = name
+
+
+@dataclasses.dataclass
+class FuncInfo:
+ """Information of a function."""
+
+ schema: NamedTypeSchema
+ is_member: bool
+
+ @staticmethod
+ def from_global_name(name: str) -> FuncInfo:
+ """Construct a `FuncInfo` from a string name of this global
function."""
+ from tvm_ffi.registry import get_global_func_metadata # noqa: PLC0415
+
+ return FuncInfo(
+ schema=NamedTypeSchema(
+ name=name,
+
schema=TypeSchema.from_json_str(get_global_func_metadata(name)["type_schema"]),
+ ),
+ is_member=False,
+ )
+
+ def gen(self, ty_map: Callable[[str], str], indent: int) -> str:
+ """Generate a function signature string for this function."""
+ try:
+ _, func_name = self.schema.name.rsplit(".", 1)
+ except ValueError:
+ func_name = self.schema.name
+ buf = StringIO()
+ buf.write(" " * indent)
+ buf.write(f"def {func_name}(")
+ if self.schema.origin != "Callable":
+ raise ValueError(f"Expected Callable type schema, but got:
{self.schema}")
+ if not self.schema.args:
+ ty_map("Any")
+ buf.write("*args: Any) -> Any: ...")
+ return buf.getvalue()
+ arg_ret = self.schema.args[0]
+ arg_args = self.schema.args[1:]
+ for i, arg in enumerate(arg_args):
+ if self.is_member and i == 0:
+ buf.write("self, ")
+ else:
+ buf.write(f"_{i}: ")
+ buf.write(arg.repr(ty_map))
+ buf.write(", ")
+ if arg_args:
+ buf.write("/")
+ buf.write(") -> ")
+ buf.write(arg_ret.repr(ty_map))
+ buf.write(": ...")
+ return buf.getvalue()
+
+
+@dataclasses.dataclass
+class ObjectInfo:
+ """Information of an object type, including its fields and methods."""
+
+ fields: list[NamedTypeSchema]
+ methods: list[FuncInfo]
+
+ @staticmethod
+ def from_type_key(type_key: str) -> ObjectInfo:
+ """Construct an `ObjectInfo` from a type key."""
+ from tvm_ffi.core import _lookup_or_register_type_info_from_type_key
# noqa: PLC0415
+
+ type_info = _lookup_or_register_type_info_from_type_key(type_key)
+ return ObjectInfo(
+ fields=[
+ NamedTypeSchema(
+ name=field.name,
+
schema=TypeSchema.from_json_str(field.metadata["type_schema"]),
+ )
+ for field in type_info.fields
+ ],
+ methods=[
+ FuncInfo(
+ schema=NamedTypeSchema(
+ name=C.FN_NAME_MAP.get(method.name, method.name),
+
schema=TypeSchema.from_json_str(method.metadata["type_schema"]),
+ ),
+ is_member=not method.is_static,
+ )
+ for method in type_info.methods
+ ],
+ )
+
+ def gen_fields(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
+ """Generate field definitions for this object."""
+ indent_str = " " * indent
+ return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in
self.fields]
+
+ def gen_methods(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
+ """Generate method definitions for this object."""
+ indent_str = " " * indent
+ ret = []
+ for method in self.methods:
+ if not method.is_member:
+ ret.append(f"{indent_str}@staticmethod")
+ ret.append(method.gen(ty_map, indent))
+ return ret
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
new file mode 100644
index 0000000..ef53143
--- /dev/null
+++ b/tests/python/test_stubgen.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pathlib import Path
+
+import pytest
+from tvm_ffi.core import TypeSchema
+from tvm_ffi.stub import consts as C
+from tvm_ffi.stub.codegen import (
+ generate_all,
+ generate_global_funcs,
+ generate_imports,
+ generate_object,
+)
+from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
+from tvm_ffi.stub.utils import FuncInfo, NamedTypeSchema, ObjectInfo, Options
+
+
+def _identity_ty_map(name: str) -> str:
+ return name
+
+
+def test_codeblock_from_begin_line_variants() -> None:
+ cases = [
+ (f"{C.STUB_BEGIN} global/example", "global", "example"),
+ (f"{C.STUB_BEGIN} object/testing.TestObjectBase", "object",
"testing.TestObjectBase"),
+ (f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"),
+ (f"{C.STUB_BEGIN} import", "import", ""),
+ ]
+ for lineno, (line, kind, param) in enumerate(cases, start=1):
+ block = CodeBlock.from_begin_line(lineno, line)
+ assert block.kind == kind
+ assert block.param == param
+ assert block.lineno_start == lineno
+ assert block.lineno_end is None
+ assert block.lines == []
+
+
+def test_codeblock_from_begin_line_ty_map_and_unknown() -> None:
+ line = f"{C.STUB_TY_MAP} custom -> mapped"
+ block = CodeBlock.from_begin_line(5, line)
+ assert block.kind == "ty-map"
+ assert block.param == "custom -> mapped"
+ assert block.lineno_start == 5
+ assert block.lineno_end == 5
+
+ with pytest.raises(ValueError):
+ CodeBlock.from_begin_line(1, f"{C.STUB_BEGIN} unsupported/kind")
+
+
+def test_fileinfo_from_file_skip_and_missing_markers(tmp_path: Path) -> None:
+ skip = tmp_path / "skip.py"
+ skip.write_text(f"print('hi')\n{C.STUB_SKIP_FILE}\n", encoding="utf-8")
+ assert FileInfo.from_file(skip) is None
+
+ plain = tmp_path / "plain.py"
+ plain.write_text("print('plain')\n", encoding="utf-8")
+ assert FileInfo.from_file(plain) is None
+
+
+def test_fileinfo_from_file_parses_blocks(tmp_path: Path) -> None:
+ content = "\n".join(
+ [
+ "first = 1",
+ f"{C.STUB_BEGIN} global/demo.func",
+ "in_stub = True",
+ C.STUB_END,
+ f"{C.STUB_TY_MAP} x -> y",
+ ]
+ )
+ path = tmp_path / "demo.py"
+ path.write_text(content, encoding="utf-8")
+
+ info = FileInfo.from_file(path)
+ assert info is not None
+ assert info.path == path.resolve()
+ assert len(info.code_blocks) == 3
+
+ first, stub, ty_map = info.code_blocks
+ assert first.kind is None and first.lines == ["first = 1"]
+
+ assert stub.kind == "global"
+ assert stub.param == "demo.func"
+ assert stub.lineno_start == 2
+ assert stub.lineno_end == 4
+ assert stub.lines == [
+ f"{C.STUB_BEGIN} global/demo.func",
+ "in_stub = True",
+ C.STUB_END,
+ ]
+
+ assert ty_map.kind == "ty-map"
+ assert ty_map.param == "x -> y"
+ assert ty_map.lineno_start == ty_map.lineno_end == 5
+ assert ty_map.lines == [f"{C.STUB_TY_MAP} x -> y"]
+
+
+def test_fileinfo_from_file_error_paths(tmp_path: Path) -> None:
+ nested = tmp_path / "nested.py"
+ nested.write_text(
+ "\n".join(
+ [
+ f"{C.STUB_BEGIN} global/outer",
+ f"{C.STUB_BEGIN} global/inner",
+ ]
+ ),
+ encoding="utf-8",
+ )
+ with pytest.raises(ValueError, match="Nested stub not permitted"):
+ FileInfo.from_file(nested)
+
+ unmatched_end = tmp_path / "unmatched.py"
+ unmatched_end.write_text(C.STUB_END + "\n", encoding="utf-8")
+ with pytest.raises(ValueError, match="Unmatched"):
+ FileInfo.from_file(unmatched_end)
+
+ unclosed = tmp_path / "unclosed.py"
+ unclosed.write_text(f"{C.STUB_BEGIN} global/method\n", encoding="utf-8")
+ with pytest.raises(ValueError, match="Unclosed stub block"):
+ FileInfo.from_file(unclosed)
+
+
+def test_funcinfo_gen_variants() -> None:
+ called: list[str] = []
+
+ def ty_map(name: str) -> str:
+ called.append(name)
+ return name
+
+ schema_no_args = NamedTypeSchema("demo.no_args", TypeSchema("Callable",
()))
+ func = FuncInfo(schema=schema_no_args, is_member=False)
+ assert func.gen(ty_map, indent=2) == " def no_args(*args: Any) -> Any:
..."
+ assert called == ["Any"]
+
+ schema_member = NamedTypeSchema(
+ "pkg.Class.method",
+ TypeSchema(
+ "Callable",
+ (
+ TypeSchema("str"),
+ TypeSchema("int"),
+ TypeSchema("float"),
+ ),
+ ),
+ )
+ member_func = FuncInfo(schema=schema_member, is_member=True)
+ assert (
+ member_func.gen(_identity_ty_map, indent=0) == "def method(self, _1:
float, /) -> str: ..."
+ )
+
+ schema_bad = NamedTypeSchema("bad", TypeSchema("int"))
+ with pytest.raises(ValueError):
+ FuncInfo(schema=schema_bad, is_member=False).gen(_identity_ty_map,
indent=0)
+
+
+def test_objectinfo_gen_fields_and_methods() -> None:
+ ty_calls: list[str] = []
+
+ def ty_map(name: str) -> str:
+ ty_calls.append(name)
+ return {"list": "Sequence", "dict": "Mapping"}.get(name, name)
+
+ info = ObjectInfo(
+ fields=[
+ NamedTypeSchema("field_a", TypeSchema("list",
(TypeSchema("int"),))),
+ NamedTypeSchema(
+ "field_b", TypeSchema("dict", (TypeSchema("str"),
TypeSchema("float")))
+ ),
+ ],
+ methods=[
+ FuncInfo(
+ schema=NamedTypeSchema("demo.static", TypeSchema("Callable",
(TypeSchema("int"),))),
+ is_member=False,
+ ),
+ FuncInfo(
+ schema=NamedTypeSchema(
+ "demo.member",
+ TypeSchema("Callable", (TypeSchema("str"),
TypeSchema("bytes"))),
+ ),
+ is_member=True,
+ ),
+ ],
+ )
+
+ assert info.gen_fields(ty_map, indent=2) == [
+ " field_a: Sequence[int]",
+ " field_b: Mapping[str, float]",
+ ]
+ assert ty_calls.count("list") == 1 and ty_calls.count("dict") == 1
+
+ methods = info.gen_methods(_identity_ty_map, indent=2)
+ assert methods == [
+ " @staticmethod",
+ " def static() -> int: ...",
+ " def member(self, /) -> str: ...",
+ ]
+
+
+def test_generate_global_funcs_updates_block() -> None:
+ code = CodeBlock(
+ kind="global",
+ param="testing",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END],
+ )
+ funcs = [
+ FuncInfo(
+ schema=NamedTypeSchema(
+ "testing.add_one",
+ TypeSchema("Callable", (TypeSchema("int"), TypeSchema("int"))),
+ ),
+ is_member=False,
+ )
+ ]
+ opts = Options(indent=2)
+ generate_global_funcs(code, funcs, _identity_ty_map, opts)
+ assert code.lines == [
+ f"{C.STUB_BEGIN} global/testing",
+ "# fmt: off",
+ "if TYPE_CHECKING:",
+ " def add_one(_0: int, /) -> int: ...",
+ "# fmt: on",
+ C.STUB_END,
+ ]
+
+
+def test_generate_global_funcs_noop_on_empty_list() -> None:
+ code = CodeBlock(
+ kind="global",
+ param="empty",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} global/empty", C.STUB_END],
+ )
+ generate_global_funcs(code, [], _identity_ty_map, Options())
+ assert code.lines == [f"{C.STUB_BEGIN} global/empty", C.STUB_END]
+
+
+def test_generate_object_fields_only_block() -> None:
+ code = CodeBlock(
+ kind="object",
+ param="testing.TestObjectDerived",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} object/testing.TestObjectDerived", C.STUB_END],
+ )
+ opts = Options(indent=4)
+ generate_object(code, _identity_ty_map, opts)
+
+ info = ObjectInfo.from_type_key("testing.TestObjectDerived")
+ expected = [
+ f"{C.STUB_BEGIN} object/testing.TestObjectDerived",
+ " " * code.indent + "# fmt: off",
+ *[(" " * code.indent) + line for line in
info.gen_fields(_identity_ty_map, indent=0)],
+ " " * code.indent + "# fmt: on",
+ C.STUB_END,
+ ]
+ assert code.lines == expected
+
+
+def test_generate_object_with_methods() -> None:
+ code = CodeBlock(
+ kind="object",
+ param="testing.TestIntPair",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} object/testing.TestIntPair", C.STUB_END],
+ )
+ opts = Options(indent=4)
+ generate_object(code, _identity_ty_map, opts)
+
+ assert code.lines[0] == f"{C.STUB_BEGIN} object/testing.TestIntPair"
+ assert code.lines[-1] == C.STUB_END
+ assert "# fmt: off" in code.lines[1]
+ assert any("if TYPE_CHECKING:" in line for line in code.lines)
+ method_lines = [
+ line for line in code.lines if "def __c_ffi_init__" in line or "def
sum" in line
+ ]
+ assert any(line.strip().startswith("def __c_ffi_init__") for line in
method_lines)
+ assert any(line.strip().startswith("def sum") for line in method_lines)
+
+
+def test_generate_imports_groups_modules() -> None:
+ code = CodeBlock(
+ kind="import",
+ param="",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
+ )
+ ty_used = {
+ "typing.Any",
+ "tvm_ffi.Tensor",
+ "testing.TestObjectBase",
+ "custom.mod.Type",
+ }
+ opts = Options(indent=4)
+ generate_imports(code, ty_used, opts)
+
+ expected_prefix = [
+ f"{C.STUB_BEGIN} import",
+ "# fmt: off",
+ "# isort: off",
+ "from __future__ import annotations",
+ "from typing import Any, TYPE_CHECKING",
+ "if TYPE_CHECKING:",
+ ]
+ assert code.lines[: len(expected_prefix)] == expected_prefix
+ assert " from tvm_ffi.testing import TestObjectBase" in code.lines
+ assert " from tvm_ffi import Tensor" in code.lines
+ assert " from custom.mod import Type" in code.lines
+ assert code.lines[-2:] == ["# fmt: on", C.STUB_END]
+
+
+def test_generate_all_builds_sorted_and_deduped_list() -> None:
+ code = CodeBlock(
+ kind="global",
+ param="all",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[" " + C.STUB_BEGIN + " global/all", C.STUB_END],
+ )
+ generate_all(
+ code,
+ names={"tvm_ffi.foo", "bar", "pkg.baz", "bar"}, # duplicates stripped
+ opt=Options(indent=2),
+ )
+ assert code.lines == [
+ " " + C.STUB_BEGIN + " global/all",
+ ' "bar",',
+ ' "baz",',
+ ' "foo",',
+ C.STUB_END,
+ ]
+
+
+def test_generate_all_noop_on_empty_names() -> None:
+ code = CodeBlock(
+ kind="global",
+ param="all-empty",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[C.STUB_BEGIN + " global/all-empty", C.STUB_END],
+ )
+ before = list(code.lines)
+ generate_all(code, names=set(), opt=Options())
+ assert code.lines == before