This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 92e150b  feat(stubgen): Generate `__all__` for proper exporting (#268)
92e150b is described below

commit 92e150b9cb8222eeadf5469eae2e06400e4af850
Author: Junru Shao <[email protected]>
AuthorDate: Sun Nov 16 15:44:58 2025 -0800

    feat(stubgen): Generate `__all__` for proper exporting (#268)
    
    This is the last component needed before introducing whole-package stub generation.
    
    With the directive below
    
    ```python
    __all__ = [
        # tvm-ffi-stubgen(begin): __all__
        # tvm-ffi-stubgen(end)
    ]
    ```
    
    this PR populates `__all__` with the names of all auto-generated global
    functions and classes, e.g.:
    
    ```diff
    __all__ = [
        # tvm-ffi-stubgen(begin): __all__
    +    "Array",
    +    "ArrayGetItem",
    +    "ArraySize",
    +    "Bytes",
    +    "FromJSONGraph",
    +    "FromJSONGraphString",
    +    "FunctionListGlobalNamesFunctor",
    +    "FunctionRemoveGlobal",
    +    "GetFirstStructuralMismatch",
    +    "GetGlobalFuncMetadata",
    +    "GetRegisteredTypeKeys",
    +    "MakeObjectFromPackedArgs",
    +    "Map",
    +    "MapCount",
    +    "MapForwardIterFunctor",
    +    "MapGetItem",
    +    "MapSize",
    +    "ModuleClearImports",
    +    "ModuleGetFunction",
    +    "ModuleGetFunctionDoc",
    +    "ModuleGetFunctionMetadata",
    +    "ModuleGetKind",
    +    "ModuleGetPropertyMask",
    +    "ModuleGetWriteFormats",
    +    "ModuleImplementsFunction",
    +    "ModuleImportModule",
    +    "ModuleInspectSource",
    +    "ModuleLoadFromFile",
    +    "ModuleWriteToFile",
    +    "Shape",
    +    "String",
    +    "StructuralHash",
    +    "SystemLib",
    +    "ToJSONGraph",
    +    "ToJSONGraphString",
        # tvm-ffi-stubgen(end)
    ]
    ```
    
    This PR also includes the refactoring needed to unit-test the core
    stubgen functionality directly, without going through the CLI.
---
 python/tvm_ffi/_ffi_api.py        |  41 +++++
 python/tvm_ffi/core.pyi           |   1 +
 python/tvm_ffi/stub/analysis.py   |  33 +---
 python/tvm_ffi/stub/cli.py        |  39 +++-
 python/tvm_ffi/stub/codegen.py    | 142 +++++----------
 python/tvm_ffi/stub/consts.py     |   4 +
 python/tvm_ffi/stub/file_utils.py |   9 +-
 python/tvm_ffi/stub/utils.py      | 119 +++++++++++++
 tests/python/test_stubgen.py      | 361 ++++++++++++++++++++++++++++++++++++++
 9 files changed, 620 insertions(+), 129 deletions(-)

diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 1bdef7a..ab6704a 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -73,3 +73,44 @@ if TYPE_CHECKING:
 # tvm-ffi-stubgen(end)
 
 registry.init_ffi_api("ffi", __name__)
+
+
+__all__ = [
+    # tvm-ffi-stubgen(begin): __all__
+    "Array",
+    "ArrayGetItem",
+    "ArraySize",
+    "Bytes",
+    "FromJSONGraph",
+    "FromJSONGraphString",
+    "FunctionListGlobalNamesFunctor",
+    "FunctionRemoveGlobal",
+    "GetFirstStructuralMismatch",
+    "GetGlobalFuncMetadata",
+    "GetRegisteredTypeKeys",
+    "MakeObjectFromPackedArgs",
+    "Map",
+    "MapCount",
+    "MapForwardIterFunctor",
+    "MapGetItem",
+    "MapSize",
+    "ModuleClearImports",
+    "ModuleGetFunction",
+    "ModuleGetFunctionDoc",
+    "ModuleGetFunctionMetadata",
+    "ModuleGetKind",
+    "ModuleGetPropertyMask",
+    "ModuleGetWriteFormats",
+    "ModuleImplementsFunction",
+    "ModuleImportModule",
+    "ModuleInspectSource",
+    "ModuleLoadFromFile",
+    "ModuleWriteToFile",
+    "Shape",
+    "String",
+    "StructuralHash",
+    "SystemLib",
+    "ToJSONGraph",
+    "ToJSONGraphString",
+    # tvm-ffi-stubgen(end)
+]
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 261803c..06a78e6 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -230,6 +230,7 @@ class TypeSchema:
     origin: str
     args: tuple[TypeSchema, ...] = ()
 
+    def __init__(self, origin: str, args: tuple[TypeSchema, ...] = ()) -> 
None: ...
     @staticmethod
     def from_json_obj(obj: dict[str, Any]) -> TypeSchema: ...
     @staticmethod
diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py
index 2207397..03dbe36 100644
--- a/python/tvm_ffi/stub/analysis.py
+++ b/python/tvm_ffi/stub/analysis.py
@@ -21,42 +21,21 @@ from __future__ import annotations
 from tvm_ffi.registry import list_global_func_names
 
 from . import consts as C
-from .file_utils import FileInfo
-from .utils import Options
+from .utils import FuncInfo
 
 
-def collect_global_funcs() -> dict[str, list[str]]:
+def collect_global_funcs() -> dict[str, list[FuncInfo]]:
     """Collect global functions from TVM FFI's global registry."""
     # Build global function table only if we are going to process blocks.
-    global_funcs: dict[str, list[str]] = {}
+    global_funcs: dict[str, list[FuncInfo]] = {}
     for name in list_global_func_names():
         try:
-            prefix, suffix = name.rsplit(".", 1)
+            prefix, _ = name.rsplit(".", 1)
         except ValueError:
             print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: 
{name}{C.TERM_RESET}")
         else:
-            global_funcs.setdefault(prefix, []).append(suffix)
+            global_funcs.setdefault(prefix, 
[]).append(FuncInfo.from_global_name(name))
     # Ensure stable ordering for deterministic output.
     for k in list(global_funcs.keys()):
-        global_funcs[k].sort()
+        global_funcs[k].sort(key=lambda x: x.schema.name)
     return global_funcs
-
-
-def collect_ty_maps(files: list[FileInfo], opt: Options) -> dict[str, str]:
-    """Collect type maps from the given files."""
-    ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
-    for file in files:
-        for code in file.code_blocks:
-            if code.kind == "ty-map":
-                try:
-                    lhs, rhs = code.param.split("->")
-                except ValueError as e:
-                    raise ValueError(
-                        f"Invalid ty_map format at line {code.lineno_start}. 
Example: `A.B -> C.D`"
-                    ) from e
-                ty_map[lhs.strip()] = rhs.strip()
-    if opt.verbose:
-        for lhs in sorted(ty_map):
-            rhs = ty_map[lhs]
-            print(f"{C.TERM_CYAN}[TY-MAP] {lhs} -> {rhs}{C.TERM_RESET}")
-    return ty_map
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
index 5728e12..85697f5 100644
--- a/python/tvm_ffi/stub/cli.py
+++ b/python/tvm_ffi/stub/cli.py
@@ -28,7 +28,7 @@ from typing import Callable
 
 from . import codegen as G
 from . import consts as C
-from .analysis import collect_global_funcs, collect_ty_maps
+from .analysis import collect_global_funcs
 from .file_utils import FileInfo, collect_files
 from .utils import Options
 
@@ -56,17 +56,37 @@ def __main__() -> int:
     """
     opt = _parse_args()
     dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
-    global_funcs = collect_global_funcs()
     files: list[FileInfo] = collect_files([Path(f) for f in opt.files])
 
     # Stage 1: Process `tvm-ffi-stubgen(ty-map)`
-    ty_map: dict[str, str] = collect_ty_maps(files, opt)
+    ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
+
+    def _stage_1(file: FileInfo) -> None:
+        for code in file.code_blocks:
+            if code.kind == "ty-map":
+                try:
+                    lhs, rhs = code.param.split("->")
+                except ValueError as e:
+                    raise ValueError(
+                        f"Invalid ty_map format at line {code.lineno_start}. 
Example: `A.B -> C.D`"
+                    ) from e
+                ty_map[lhs.strip()] = rhs.strip()
+
+    for file in files:
+        try:
+            _stage_1(file)
+        except Exception:
+            print(
+                f'{C.TERM_RED}[Failed] File "{file.path}": 
{traceback.format_exc()}{C.TERM_RESET}'
+            )
 
     # Stage 2: Process
     # - `tvm-ffi-stubgen(begin): global/...`
     # - `tvm-ffi-stubgen(begin): object/...`
+    global_funcs = collect_global_funcs()
 
     def _stage_2(file: FileInfo) -> None:
+        all_defined = set()
         if opt.verbose:
             print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
         ty_used: set[str] = set()
@@ -75,17 +95,26 @@ def __main__() -> int:
         # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
         for code in file.code_blocks:
             if code.kind == "global":
-                G.generate_global_funcs(code, global_funcs, fn_ty_map_fn, opt)
+                funcs = global_funcs.get(code.param, [])
+                for func in funcs:
+                    all_defined.add(func.schema.name)
+                G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt)
         # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
         for code in file.code_blocks:
             if code.kind == "object":
+                type_key = code.param
+                ty_on_file.add(ty_map.get(type_key, type_key))
                 G.generate_object(code, fn_ty_map_fn, opt)
-                ty_on_file.add(ty_map.get(code.param, code.param))
         # Stage 2.3. Add imports for used types.
         for code in file.code_blocks:
             if code.kind == "import":
                 G.generate_imports(code, ty_used - ty_on_file, opt)
                 break  # Only one import block per file is supported for now.
+        # Stage 2.4. Add `__all__` for defined classes and functions.
+        for code in file.code_blocks:
+            if code.kind == "__all__":
+                G.generate_all(code, all_defined | ty_on_file, opt)
+                break  # Only one __all__ block per file is supported for now.
         file.update(show_diff=opt.verbose, dry_run=opt.dry_run)
 
     for file in files:
diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py
index be33f16..c15624a 100644
--- a/python/tvm_ffi/stub/codegen.py
+++ b/python/tvm_ffi/stub/codegen.py
@@ -18,77 +18,36 @@
 
 from __future__ import annotations
 
-from io import StringIO
 from typing import Callable
 
-from tvm_ffi.core import TypeSchema, 
_lookup_or_register_type_info_from_type_key
-from tvm_ffi.registry import get_global_func_metadata
-
 from . import consts as C
 from .file_utils import CodeBlock
-from .utils import Options
-
-
-def generate_func_signature(
-    schema: TypeSchema,
-    func_name: str,
-    ty_map: Callable[[str], str],
-    is_member: bool,
-) -> str:
-    """Generate a function signature string from a TypeSchema."""
-    buf = StringIO()
-    buf.write(f"def {func_name}(")
-    if schema.origin != "Callable":
-        raise ValueError(f"Expected Callable type schema, but got: {schema}")
-    if not schema.args:
-        ty_map("Any")
-        buf.write("*args: Any) -> Any: ...")
-        return buf.getvalue()
-    arg_ret = schema.args[0]
-    arg_args = schema.args[1:]
-    for i, arg in enumerate(arg_args):
-        if is_member and i == 0:
-            buf.write("self, ")
-        else:
-            buf.write(f"_{i}: ")
-            buf.write(arg.repr(ty_map))
-            buf.write(", ")
-    if arg_args:
-        buf.write("/")
-    buf.write(") -> ")
-    buf.write(arg_ret.repr(ty_map))
-    buf.write(": ...")
-    return buf.getvalue()
+from .utils import FuncInfo, ObjectInfo, Options
 
 
 def generate_global_funcs(
-    code: CodeBlock,
-    global_funcs: dict[str, list[str]],
-    fn_ty_map: Callable[[str], str],
-    opt: Options,
+    code: CodeBlock, global_funcs: list[FuncInfo], fn_ty_map: Callable[[str], 
str], opt: Options
 ) -> None:
     """Generate function signatures for global functions."""
     assert len(code.lines) >= 2
-    indent = " " * code.indent
-    indent_long = " " * (code.indent + opt.indent)
-    prefix = code.param
+    if not global_funcs:
+        return
     results: list[str] = [
-        generate_func_signature(
-            
TypeSchema.from_json_str(get_global_func_metadata(f"{prefix}.{name}")["type_schema"]),
-            name,
-            ty_map=fn_ty_map,
-            is_member=False,
-        )
-        for name in global_funcs.get(prefix, [])
+        "# fmt: off",
+        "if TYPE_CHECKING:",
+        *[
+            func.gen(
+                fn_ty_map,
+                indent=opt.indent,
+            )
+            for func in global_funcs
+        ],
+        "# fmt: on",
     ]
-    if not results:
-        return
+    indent = " " * code.indent
     code.lines = [
         code.lines[0],
-        f"{indent}# fmt: off",
-        f"{indent}if TYPE_CHECKING:",
-        *[indent_long + sig for sig in results],
-        f"{indent}# fmt: on",
+        *[indent + line for line in results],
         code.lines[-1],
     ]
 
@@ -96,50 +55,30 @@ def generate_global_funcs(
 def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt: 
Options) -> None:
     """Generate a class definition for an object type."""
     assert len(code.lines) >= 2
-    type_key = code.param
-    type_info = _lookup_or_register_type_info_from_type_key(type_key)
+    info = ObjectInfo.from_type_key(code.param)
+    if info.methods:
+        results = [
+            "# fmt: off",
+            *info.gen_fields(fn_ty_map, indent=0),
+            "if TYPE_CHECKING:",
+            *info.gen_methods(fn_ty_map, indent=opt.indent),
+            "# fmt: on",
+        ]
+    else:
+        results = [
+            "# fmt: off",
+            *info.gen_fields(fn_ty_map, indent=0),
+            "# fmt: on",
+        ]
     indent = " " * code.indent
-    indent_long = " " * (code.indent + opt.indent)
-
-    fields: list[str] = []
-    for field in type_info.fields:
-        fields.append(
-            f"{indent}{field.name}: "
-            + 
TypeSchema.from_json_str(field.metadata["type_schema"]).repr(fn_ty_map)
-        )
-
-    methods: list[str] = []
-    if type_info.methods:
-        methods = [f"{indent}if TYPE_CHECKING:"]
-    for method in type_info.methods:
-        if method.is_static:
-            methods.append(f"{indent_long}@staticmethod")
-        methods.append(
-            indent_long
-            + generate_func_signature(
-                TypeSchema.from_json_str(method.metadata["type_schema"]),
-                {
-                    "__ffi_init__": "__c_ffi_init__",
-                }.get(method.name, method.name),
-                fn_ty_map,
-                is_member=not method.is_static,
-            )
-        )
     code.lines = [
         code.lines[0],
-        f"{indent}# fmt: off",
-        *fields,
-        *methods,
-        f"{indent}# fmt: on",
+        *[indent + line for line in results],
         code.lines[-1],
     ]
 
 
-def generate_imports(
-    code: CodeBlock,
-    ty_used: set[str],
-    opt: Options,
-) -> None:
+def generate_imports(code: CodeBlock, ty_used: set[str], opt: Options) -> None:
     """Generate import statements for the types used in the stub."""
     ty_collected: dict[str, list[str]] = {}
     for ty in ty_used:
@@ -181,3 +120,18 @@ def generate_imports(
             "# fmt: on",
             code.lines[-1],
         ]
+
+
+def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None:
+    """Generate an `__all__` variable for the given names."""
+    assert len(code.lines) >= 2
+    if not names:
+        return
+
+    indent = " " * code.indent
+    names = {f.rsplit(".", 1)[-1] for f in names}
+    code.lines = [
+        code.lines[0],
+        *[f'{indent}"{name}",' for name in sorted(names)],
+        code.lines[-1],
+    ]
diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py
index 6443c95..6922254 100644
--- a/python/tvm_ffi/stub/consts.py
+++ b/python/tvm_ffi/stub/consts.py
@@ -54,3 +54,7 @@ MOD_MAP = {
     "testing": "tvm_ffi.testing",
     "ffi": "tvm_ffi",
 }
+
+FN_NAME_MAP = {
+    "__ffi_init__": "__c_ffi_init__",
+}
diff --git a/python/tvm_ffi/stub/file_utils.py 
b/python/tvm_ffi/stub/file_utils.py
index 6f95025..f100c55 100644
--- a/python/tvm_ffi/stub/file_utils.py
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -31,7 +31,7 @@ from . import consts as C
 class CodeBlock:
     """A block of code to be generated in a stub file."""
 
-    kind: Literal["global", "object", "ty-map", "import", None]
+    kind: Literal["global", "object", "ty-map", "import", "__all__", None]
     param: str
     lineno_start: int
     lineno_end: int | None
@@ -39,7 +39,7 @@ class CodeBlock:
 
     def __post_init__(self) -> None:
         """Validate the code block after initialization."""
-        assert self.kind in {"global", "object", "ty-map", "import", None}
+        assert self.kind in {"global", "object", "ty-map", "import", 
"__all__", None}
 
     @property
     def indent(self) -> int:
@@ -74,6 +74,9 @@ class CodeBlock:
         elif stub.startswith("import"):
             kind = "import"
             param = ""
+        elif stub == "__all__":
+            kind = "__all__"
+            param = ""
         else:
             raise ValueError(f"Unknown stub type `{stub}` at line {lineo}")
         return CodeBlock(
@@ -178,7 +181,7 @@ def collect_files(paths: list[Path]) -> list[FileInfo]:
 
     def _on_error(e: Exception) -> None:
         print(
-            f'{C.TERM_RED}[Failed] File 
"{file}"\n{traceback.format_exc()}{C.TERM_RESET}',
+            f"{C.TERM_RED}[Error]\n{traceback.format_exc()}{C.TERM_RESET}",
             end="",
             flush=True,
         )
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
index 7f461c4..e8beb41 100644
--- a/python/tvm_ffi/stub/utils.py
+++ b/python/tvm_ffi/stub/utils.py
@@ -19,6 +19,12 @@
 from __future__ import annotations
 
 import dataclasses
+from io import StringIO
+from typing import Callable
+
+from tvm_ffi.core import TypeSchema
+
+from . import consts as C
 
 
 @dataclasses.dataclass
@@ -30,3 +36,116 @@ class Options:
     files: list[str] = dataclasses.field(default_factory=list)
     verbose: bool = False
     dry_run: bool = False
+
+
[email protected](init=False)
+class NamedTypeSchema(TypeSchema):
+    """A type schema with an associated name."""
+
+    name: str
+
+    def __init__(self, name: str, schema: TypeSchema) -> None:
+        """Initialize a `NamedTypeSchema` with the given name and type 
schema."""
+        super().__init__(origin=schema.origin, args=schema.args)
+        self.name = name
+
+
[email protected]
+class FuncInfo:
+    """Information of a function."""
+
+    schema: NamedTypeSchema
+    is_member: bool
+
+    @staticmethod
+    def from_global_name(name: str) -> FuncInfo:
+        """Construct a `FuncInfo` from a string name of this global 
function."""
+        from tvm_ffi.registry import get_global_func_metadata  # noqa: PLC0415
+
+        return FuncInfo(
+            schema=NamedTypeSchema(
+                name=name,
+                
schema=TypeSchema.from_json_str(get_global_func_metadata(name)["type_schema"]),
+            ),
+            is_member=False,
+        )
+
+    def gen(self, ty_map: Callable[[str], str], indent: int) -> str:
+        """Generate a function signature string for this function."""
+        try:
+            _, func_name = self.schema.name.rsplit(".", 1)
+        except ValueError:
+            func_name = self.schema.name
+        buf = StringIO()
+        buf.write(" " * indent)
+        buf.write(f"def {func_name}(")
+        if self.schema.origin != "Callable":
+            raise ValueError(f"Expected Callable type schema, but got: 
{self.schema}")
+        if not self.schema.args:
+            ty_map("Any")
+            buf.write("*args: Any) -> Any: ...")
+            return buf.getvalue()
+        arg_ret = self.schema.args[0]
+        arg_args = self.schema.args[1:]
+        for i, arg in enumerate(arg_args):
+            if self.is_member and i == 0:
+                buf.write("self, ")
+            else:
+                buf.write(f"_{i}: ")
+                buf.write(arg.repr(ty_map))
+                buf.write(", ")
+        if arg_args:
+            buf.write("/")
+        buf.write(") -> ")
+        buf.write(arg_ret.repr(ty_map))
+        buf.write(": ...")
+        return buf.getvalue()
+
+
[email protected]
+class ObjectInfo:
+    """Information of an object type, including its fields and methods."""
+
+    fields: list[NamedTypeSchema]
+    methods: list[FuncInfo]
+
+    @staticmethod
+    def from_type_key(type_key: str) -> ObjectInfo:
+        """Construct an `ObjectInfo` from a type key."""
+        from tvm_ffi.core import _lookup_or_register_type_info_from_type_key  
# noqa: PLC0415
+
+        type_info = _lookup_or_register_type_info_from_type_key(type_key)
+        return ObjectInfo(
+            fields=[
+                NamedTypeSchema(
+                    name=field.name,
+                    
schema=TypeSchema.from_json_str(field.metadata["type_schema"]),
+                )
+                for field in type_info.fields
+            ],
+            methods=[
+                FuncInfo(
+                    schema=NamedTypeSchema(
+                        name=C.FN_NAME_MAP.get(method.name, method.name),
+                        
schema=TypeSchema.from_json_str(method.metadata["type_schema"]),
+                    ),
+                    is_member=not method.is_static,
+                )
+                for method in type_info.methods
+            ],
+        )
+
+    def gen_fields(self, ty_map: Callable[[str], str], indent: int) -> 
list[str]:
+        """Generate field definitions for this object."""
+        indent_str = " " * indent
+        return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in 
self.fields]
+
+    def gen_methods(self, ty_map: Callable[[str], str], indent: int) -> 
list[str]:
+        """Generate method definitions for this object."""
+        indent_str = " " * indent
+        ret = []
+        for method in self.methods:
+            if not method.is_member:
+                ret.append(f"{indent_str}@staticmethod")
+            ret.append(method.gen(ty_map, indent))
+        return ret
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
new file mode 100644
index 0000000..ef53143
--- /dev/null
+++ b/tests/python/test_stubgen.py
@@ -0,0 +1,361 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pathlib import Path
+
+import pytest
+from tvm_ffi.core import TypeSchema
+from tvm_ffi.stub import consts as C
+from tvm_ffi.stub.codegen import (
+    generate_all,
+    generate_global_funcs,
+    generate_imports,
+    generate_object,
+)
+from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
+from tvm_ffi.stub.utils import FuncInfo, NamedTypeSchema, ObjectInfo, Options
+
+
+def _identity_ty_map(name: str) -> str:
+    return name
+
+
+def test_codeblock_from_begin_line_variants() -> None:
+    cases = [
+        (f"{C.STUB_BEGIN} global/example", "global", "example"),
+        (f"{C.STUB_BEGIN} object/testing.TestObjectBase", "object", 
"testing.TestObjectBase"),
+        (f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"),
+        (f"{C.STUB_BEGIN} import", "import", ""),
+    ]
+    for lineno, (line, kind, param) in enumerate(cases, start=1):
+        block = CodeBlock.from_begin_line(lineno, line)
+        assert block.kind == kind
+        assert block.param == param
+        assert block.lineno_start == lineno
+        assert block.lineno_end is None
+        assert block.lines == []
+
+
+def test_codeblock_from_begin_line_ty_map_and_unknown() -> None:
+    line = f"{C.STUB_TY_MAP} custom -> mapped"
+    block = CodeBlock.from_begin_line(5, line)
+    assert block.kind == "ty-map"
+    assert block.param == "custom -> mapped"
+    assert block.lineno_start == 5
+    assert block.lineno_end == 5
+
+    with pytest.raises(ValueError):
+        CodeBlock.from_begin_line(1, f"{C.STUB_BEGIN} unsupported/kind")
+
+
+def test_fileinfo_from_file_skip_and_missing_markers(tmp_path: Path) -> None:
+    skip = tmp_path / "skip.py"
+    skip.write_text(f"print('hi')\n{C.STUB_SKIP_FILE}\n", encoding="utf-8")
+    assert FileInfo.from_file(skip) is None
+
+    plain = tmp_path / "plain.py"
+    plain.write_text("print('plain')\n", encoding="utf-8")
+    assert FileInfo.from_file(plain) is None
+
+
+def test_fileinfo_from_file_parses_blocks(tmp_path: Path) -> None:
+    content = "\n".join(
+        [
+            "first = 1",
+            f"{C.STUB_BEGIN} global/demo.func",
+            "in_stub = True",
+            C.STUB_END,
+            f"{C.STUB_TY_MAP} x -> y",
+        ]
+    )
+    path = tmp_path / "demo.py"
+    path.write_text(content, encoding="utf-8")
+
+    info = FileInfo.from_file(path)
+    assert info is not None
+    assert info.path == path.resolve()
+    assert len(info.code_blocks) == 3
+
+    first, stub, ty_map = info.code_blocks
+    assert first.kind is None and first.lines == ["first = 1"]
+
+    assert stub.kind == "global"
+    assert stub.param == "demo.func"
+    assert stub.lineno_start == 2
+    assert stub.lineno_end == 4
+    assert stub.lines == [
+        f"{C.STUB_BEGIN} global/demo.func",
+        "in_stub = True",
+        C.STUB_END,
+    ]
+
+    assert ty_map.kind == "ty-map"
+    assert ty_map.param == "x -> y"
+    assert ty_map.lineno_start == ty_map.lineno_end == 5
+    assert ty_map.lines == [f"{C.STUB_TY_MAP} x -> y"]
+
+
+def test_fileinfo_from_file_error_paths(tmp_path: Path) -> None:
+    nested = tmp_path / "nested.py"
+    nested.write_text(
+        "\n".join(
+            [
+                f"{C.STUB_BEGIN} global/outer",
+                f"{C.STUB_BEGIN} global/inner",
+            ]
+        ),
+        encoding="utf-8",
+    )
+    with pytest.raises(ValueError, match="Nested stub not permitted"):
+        FileInfo.from_file(nested)
+
+    unmatched_end = tmp_path / "unmatched.py"
+    unmatched_end.write_text(C.STUB_END + "\n", encoding="utf-8")
+    with pytest.raises(ValueError, match="Unmatched"):
+        FileInfo.from_file(unmatched_end)
+
+    unclosed = tmp_path / "unclosed.py"
+    unclosed.write_text(f"{C.STUB_BEGIN} global/method\n", encoding="utf-8")
+    with pytest.raises(ValueError, match="Unclosed stub block"):
+        FileInfo.from_file(unclosed)
+
+
+def test_funcinfo_gen_variants() -> None:
+    called: list[str] = []
+
+    def ty_map(name: str) -> str:
+        called.append(name)
+        return name
+
+    schema_no_args = NamedTypeSchema("demo.no_args", TypeSchema("Callable", 
()))
+    func = FuncInfo(schema=schema_no_args, is_member=False)
+    assert func.gen(ty_map, indent=2) == "  def no_args(*args: Any) -> Any: 
..."
+    assert called == ["Any"]
+
+    schema_member = NamedTypeSchema(
+        "pkg.Class.method",
+        TypeSchema(
+            "Callable",
+            (
+                TypeSchema("str"),
+                TypeSchema("int"),
+                TypeSchema("float"),
+            ),
+        ),
+    )
+    member_func = FuncInfo(schema=schema_member, is_member=True)
+    assert (
+        member_func.gen(_identity_ty_map, indent=0) == "def method(self, _1: 
float, /) -> str: ..."
+    )
+
+    schema_bad = NamedTypeSchema("bad", TypeSchema("int"))
+    with pytest.raises(ValueError):
+        FuncInfo(schema=schema_bad, is_member=False).gen(_identity_ty_map, 
indent=0)
+
+
+def test_objectinfo_gen_fields_and_methods() -> None:
+    ty_calls: list[str] = []
+
+    def ty_map(name: str) -> str:
+        ty_calls.append(name)
+        return {"list": "Sequence", "dict": "Mapping"}.get(name, name)
+
+    info = ObjectInfo(
+        fields=[
+            NamedTypeSchema("field_a", TypeSchema("list", 
(TypeSchema("int"),))),
+            NamedTypeSchema(
+                "field_b", TypeSchema("dict", (TypeSchema("str"), 
TypeSchema("float")))
+            ),
+        ],
+        methods=[
+            FuncInfo(
+                schema=NamedTypeSchema("demo.static", TypeSchema("Callable", 
(TypeSchema("int"),))),
+                is_member=False,
+            ),
+            FuncInfo(
+                schema=NamedTypeSchema(
+                    "demo.member",
+                    TypeSchema("Callable", (TypeSchema("str"), 
TypeSchema("bytes"))),
+                ),
+                is_member=True,
+            ),
+        ],
+    )
+
+    assert info.gen_fields(ty_map, indent=2) == [
+        "  field_a: Sequence[int]",
+        "  field_b: Mapping[str, float]",
+    ]
+    assert ty_calls.count("list") == 1 and ty_calls.count("dict") == 1
+
+    methods = info.gen_methods(_identity_ty_map, indent=2)
+    assert methods == [
+        "  @staticmethod",
+        "  def static() -> int: ...",
+        "  def member(self, /) -> str: ...",
+    ]
+
+
+def test_generate_global_funcs_updates_block() -> None:
+    code = CodeBlock(
+        kind="global",
+        param="testing",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END],
+    )
+    funcs = [
+        FuncInfo(
+            schema=NamedTypeSchema(
+                "testing.add_one",
+                TypeSchema("Callable", (TypeSchema("int"), TypeSchema("int"))),
+            ),
+            is_member=False,
+        )
+    ]
+    opts = Options(indent=2)
+    generate_global_funcs(code, funcs, _identity_ty_map, opts)
+    assert code.lines == [
+        f"{C.STUB_BEGIN} global/testing",
+        "# fmt: off",
+        "if TYPE_CHECKING:",
+        "  def add_one(_0: int, /) -> int: ...",
+        "# fmt: on",
+        C.STUB_END,
+    ]
+
+
+def test_generate_global_funcs_noop_on_empty_list() -> None:
+    code = CodeBlock(
+        kind="global",
+        param="empty",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.STUB_BEGIN} global/empty", C.STUB_END],
+    )
+    generate_global_funcs(code, [], _identity_ty_map, Options())
+    assert code.lines == [f"{C.STUB_BEGIN} global/empty", C.STUB_END]
+
+
+def test_generate_object_fields_only_block() -> None:
+    code = CodeBlock(
+        kind="object",
+        param="testing.TestObjectDerived",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.STUB_BEGIN} object/testing.TestObjectDerived", C.STUB_END],
+    )
+    opts = Options(indent=4)
+    generate_object(code, _identity_ty_map, opts)
+
+    info = ObjectInfo.from_type_key("testing.TestObjectDerived")
+    expected = [
+        f"{C.STUB_BEGIN} object/testing.TestObjectDerived",
+        " " * code.indent + "# fmt: off",
+        *[(" " * code.indent) + line for line in 
info.gen_fields(_identity_ty_map, indent=0)],
+        " " * code.indent + "# fmt: on",
+        C.STUB_END,
+    ]
+    assert code.lines == expected
+
+
+def test_generate_object_with_methods() -> None:
+    code = CodeBlock(
+        kind="object",
+        param="testing.TestIntPair",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.STUB_BEGIN} object/testing.TestIntPair", C.STUB_END],
+    )
+    opts = Options(indent=4)
+    generate_object(code, _identity_ty_map, opts)
+
+    assert code.lines[0] == f"{C.STUB_BEGIN} object/testing.TestIntPair"
+    assert code.lines[-1] == C.STUB_END
+    assert "# fmt: off" in code.lines[1]
+    assert any("if TYPE_CHECKING:" in line for line in code.lines)
+    method_lines = [
+        line for line in code.lines if "def __c_ffi_init__" in line or "def 
sum" in line
+    ]
+    assert any(line.strip().startswith("def __c_ffi_init__") for line in 
method_lines)
+    assert any(line.strip().startswith("def sum") for line in method_lines)
+
+
+def test_generate_imports_groups_modules() -> None:
+    code = CodeBlock(
+        kind="import",
+        param="",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
+    )
+    ty_used = {
+        "typing.Any",
+        "tvm_ffi.Tensor",
+        "testing.TestObjectBase",
+        "custom.mod.Type",
+    }
+    opts = Options(indent=4)
+    generate_imports(code, ty_used, opts)
+
+    expected_prefix = [
+        f"{C.STUB_BEGIN} import",
+        "# fmt: off",
+        "# isort: off",
+        "from __future__ import annotations",
+        "from typing import Any, TYPE_CHECKING",
+        "if TYPE_CHECKING:",
+    ]
+    assert code.lines[: len(expected_prefix)] == expected_prefix
+    assert "    from tvm_ffi.testing import TestObjectBase" in code.lines
+    assert "    from tvm_ffi import Tensor" in code.lines
+    assert "    from custom.mod import Type" in code.lines
+    assert code.lines[-2:] == ["# fmt: on", C.STUB_END]
+
+
+def test_generate_all_builds_sorted_and_deduped_list() -> None:
+    code = CodeBlock(
+        kind="global",
+        param="all",
+        lineno_start=1,
+        lineno_end=2,
+        lines=["    " + C.STUB_BEGIN + " global/all", C.STUB_END],
+    )
+    generate_all(
+        code,
+        names={"tvm_ffi.foo", "bar", "pkg.baz", "bar"},  # duplicates stripped
+        opt=Options(indent=2),
+    )
+    assert code.lines == [
+        "    " + C.STUB_BEGIN + " global/all",
+        '    "bar",',
+        '    "baz",',
+        '    "foo",',
+        C.STUB_END,
+    ]
+
+
+def test_generate_all_noop_on_empty_names() -> None:
+    code = CodeBlock(
+        kind="global",
+        param="all-empty",
+        lineno_start=1,
+        lineno_end=2,
+        lines=[C.STUB_BEGIN + " global/all-empty", C.STUB_END],
+    )
+    before = list(code.lines)
+    generate_all(code, names=set(), opt=Options())
+    assert code.lines == before

Reply via email to