This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b58c2e3 feat(stubgen): Package generation with `--init-*` flags (#295)
b58c2e3 is described below
commit b58c2e3d7deadbd60c7480f5c84633966260bc9a
Author: Junru Shao <[email protected]>
AuthorDate: Thu Dec 18 08:39:10 2025 -0800
feat(stubgen): Package generation with `--init-*` flags (#295)
This commit adds explicit `--init-*` flags so that packages can be
bootstrapped cleanly from a single command.
**Usage** (for the my-ffi-extension example):
```
tvm-ffi-stubgen examples/packaging/python \
--dlls examples/packaging/build/libmy_ffi_extension.dylib \
--init-pypkg my-ffi-extension \
--init-lib my_ffi_extension \
--init-prefix "my_ffi_extension."
```
What each flag means:
- `PATH`: the files/directories to scan; the first entry (or its parent
  directory, if the entry is a file) becomes the root under which
  `_ffi_api.py`/`__init__.py` stubs are generated.
- `--dlls`: preload the built shared library so global function/type
metadata is available (here the dylib from the packaging example build).
- `--init-pypkg`: the published Python package name (wheel/sdist name),
used in the loader string.
- `--init-lib`: the CMake target/basename of the shared library
(`lib<init-lib>.so`/`.dylib`/`.dll`).
- `--init-prefix`: registry prefix to include when generating
globals/objects (e.g. `my_ffi_extension.`).
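A rough sketch of how these flags fit together may help. The helper names below (`split_list_arg`, `init_targets`, `shared_lib_name`) are hypothetical illustrations, not the tool's API. The list splitting mirrors the `;`-separated `--dlls`/`--imports` convention, the root-directory rule mirrors the CLI change in this commit (first `PATH` entry, or its parent if it is a file), and the library filename simply follows the `lib<init-lib>.so`/`.dylib`/`.dll` pattern stated above.

```python
from __future__ import annotations

import sys
from pathlib import Path

# Hypothetical helpers for illustration only; they are not part of
# tvm-ffi-stubgen's public API.


def split_list_arg(arg: str | None) -> list[str]:
    """Split a ';'-separated flag value such as --dlls or --imports."""
    if not arg:
        return []
    return [item.strip() for item in arg.split(";") if item.strip()]


def init_targets(first_path: str, registry_prefix: str) -> tuple[Path, Path]:
    """Return the (_ffi_api.py, __init__.py) paths for one registry prefix."""
    root = Path(first_path).resolve()
    if root.is_file():
        # A file entry contributes its parent directory as the root.
        root = root.parent
    # Dotted registry prefixes map onto nested package directories.
    package_dir = root / registry_prefix.rstrip(".").replace(".", "/")
    return package_dir / "_ffi_api.py", package_dir / "__init__.py"


def shared_lib_name(init_lib: str, platform: str = sys.platform) -> str:
    """Build the lib<init-lib>.<ext> filename described for --init-lib."""
    if platform == "darwin":
        ext = ".dylib"
    elif platform.startswith("win"):
        ext = ".dll"
    else:
        ext = ".so"
    return f"lib{init_lib}{ext}"


print(split_list_arg("build/libmy_ext.so; build/libmy_2nd_ext.so"))
# ['build/libmy_ext.so', 'build/libmy_2nd_ext.so']
print(shared_lib_name("my_ffi_extension", "darwin"))
# libmy_ffi_extension.dylib
```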
**Behavior.**
- The tool creates missing `_ffi_api.py` and `__init__.py` under the
derived path/prefix.
- It scans existing stub blocks, honors `ty-map` and `import-object`
  directives, and adds an `__all__` section without duplication on reruns.
- Re-running the command is idempotent: previously generated sections
are detected and not appended again.
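The rerun behavior can be pictured as a set difference between what the loaded libraries register and what already has a stub block on disk. A minimal sketch under that assumption: `existing_blocks` and `missing_stubs` are hypothetical names, and the plain-text marker scan stands in for the tool's own block parser. A rerun over already-generated files finds nothing left to append.

```python
from __future__ import annotations

# Hypothetical sketch of rerun detection; not the tool's actual implementation.
BEGIN = "# tvm-ffi-stubgen(begin):"


def existing_blocks(source: str) -> tuple[set[str], set[str]]:
    """Collect registry prefixes and type keys that already have stub blocks."""
    prefixes: set[str] = set()
    type_keys: set[str] = set()
    for raw in source.splitlines():
        line = raw.strip()
        if not line.startswith(BEGIN):
            continue
        kind, _, param = line[len(BEGIN):].strip().partition("/")
        if kind == "global":
            # global/<registry-prefix>@<import-from>: keep only the prefix.
            prefixes.add(param.split("@", 1)[0])
        elif kind == "object":
            type_keys.add(param)
    return prefixes, type_keys


def missing_stubs(
    source: str,
    registry_funcs: dict[str, list[str]],
    registry_types: list[str],
) -> tuple[dict[str, list[str]], list[str]]:
    """Return only the functions and types that still need generated blocks."""
    prefixes, type_keys = existing_blocks(source)
    funcs = {p: sorted(n) for p, n in registry_funcs.items() if p not in prefixes}
    types = sorted(set(registry_types) - type_keys)
    return funcs, types


generated = (
    "# tvm-ffi-stubgen(begin): global/my_ffi_extension\n"
    "    # tvm-ffi-stubgen(begin): object/my_ffi_extension.IntPair\n"
)
print(missing_stubs(generated, {"my_ffi_extension": ["raise_error"]}, ["my_ffi_extension.IntPair"]))
# ({}, [])
```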
**Example.** See `my_ffi_extension/__init__.py` and
`my_ffi_extension/_ffi_api.py` under `examples/packaging/python/`, both
regenerated by this commit.
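In the generated `_ffi_api.py`, the `__all__` block lists the generated functions and classes, plus `LIB` when the `load_lib_module` loader is imported. A minimal sketch of that selection logic using a hypothetical `build_all_block` helper; putting `LIB` first and sorting the rest is an assumption chosen to reproduce the example file's ordering, not a documented rule.

```python
from __future__ import annotations

# Hypothetical helper; the real generator may format and order entries differently.


def build_all_block(
    defined_funcs: set[str],
    defined_types: set[str],
    loader_imported: bool,
    indent: int = 4,
) -> str:
    """Render the body of a `# tvm-ffi-stubgen(begin): __all__` block."""
    names = sorted(defined_funcs | defined_types)
    if loader_imported:
        # LIB is exported only when the file imports the shared-library loader.
        names = ["LIB", *[n for n in names if n != "LIB"]]
    pad = " " * indent
    lines = [f"{pad}# tvm-ffi-stubgen(begin): __all__"]
    lines += [f'{pad}"{name}",' for name in names]
    lines.append(f"{pad}# tvm-ffi-stubgen(end)")
    return "\n".join(lines)


print("__all__ = [")
print(build_all_block({"raise_error"}, {"IntPair"}, loader_imported=True))
print("]")
```

With these inputs the printed list matches the `__all__` block of the example `_ffi_api.py`.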
**How to use this tool**
```
$ tvm-ffi-stubgen --help
usage: tvm-ffi-stubgen [-h] [--imports IMPORTS] [--dlls LIBS] [--init-pypkg
INIT_PYPKG] [--init-lib INIT_LIB] [--init-prefix INIT_PREFIX] [--indent INDENT]
[--verbose] [--dry-run] [PATH ...]
Generate type stubs for TVM FFI extensions.
In `--init-*` mode, it generates missing `_ffi_api.py` and `__init__.py`
files, based on the registered global functions and object types in the loaded
libraries.
In normal mode, it processes the given files/directories in-place,
generating type stubs inside special `tvm-ffi-stubgen` blocks. Scroll down for
more details.
positional arguments:
PATH Files or directories to process. Directories are
scanned recursively; only .py and .pyi files are modified. Use tvm-ffi-stubgen
markers to select where stubs are generated. (default: None)
options:
-h, --help show this help message and exit
--imports IMPORTS Additional imports to load before generation,
separated by ';' (e.g. 'pkgA;pkgB.submodule'). (default: )
--dlls LIBS Shared libraries to preload before generation (e.g.
TVM runtime or your extension), separated by ';'. This ensures global function
and object metadata is available. Platform-specific suffixes like
.so/.dylib/.dll are supported. (default: )
--init-pypkg INIT_PYPKG
Python package name to generate stubs for (e.g.
apache-tvm-ffi). Required together with --init-lib and --init-prefix.
(default: )
--init-lib INIT_LIB CMake target that produces the shared library to
load for stub generation (e.g. tvm_ffi_shared). Required together with
--init-pypkg and --init-prefix. (default: )
--init-prefix INIT_PREFIX
Global function/object prefix to include when
generating stubs (e.g. tvm_ffi.). Required together with --init-pypkg and
--init-lib. (default: )
--indent INDENT Extra spaces added inside each generated block,
relative to the indentation of the corresponding '# tvm-ffi-stubgen(begin):'
line. (default: 4)
--verbose Print a unified diff of changes to each file. This
is useful for debugging or previewing changes before applying them. (default:
False)
--dry-run Don't write changes to files. This is useful for
previewing changes without modifying any files. (default: False)
========
Examples
========
# Single file
tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py
# Recursively scan directories
tvm-ffi-stubgen python/tvm_ffi examples/packaging/python/my_ffi_extension
# Preload extension libraries
tvm-ffi-stubgen --dlls "build/libmy_ext.so;build/libmy_2nd_ext.so" my_pkg/_ffi_api.py
# Package-level init (my-ffi-extension)
tvm-ffi-stubgen examples/packaging/python \
--dlls examples/packaging/build/libmy_ffi_extension.dylib \
--init-pypkg my-ffi-extension \
--init-lib my_ffi_extension \
--init-prefix "my_ffi_extension."
=====================
Syntax of stub blocks
=====================
Global functions
~~~~~~~~~~~~~~~~
```
# tvm-ffi-stubgen(begin): global/<registry-prefix>@<import-from (default: tvm_ffi)>
# tvm-ffi-stubgen(end)
```
Generates TYPE_CHECKING-only stubs for functions in the global registry
under the prefix.
Example:
```
# tvm-ffi-stubgen(begin): global/[email protected]
# fmt: off
_FFI_INIT_FUNC("ffi", __name__)
if TYPE_CHECKING:
    def Array(*args: Any) -> Any: ...
    def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
    def ArraySize(_0: Sequence[Any], /) -> int: ...
    def Bytes(_0: bytes, /) -> bytes: ...
    ...
    def StructuralHash(_0: Any, _1: bool, _2: bool, /) -> int: ...
    def SystemLib(*args: Any) -> Any: ...
    def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...
    def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
# fmt: on
# tvm-ffi-stubgen(end)
```
Objects
~~~~~~~
```
# tvm-ffi-stubgen(begin): object/<type_key>
# tvm-ffi-stubgen(end)
```
Generates fields/methods for a class defined using TVM-FFI Object APIs.
Example:
```
@register_object("ffi.reflection.AccessPath")
class AccessPath(tvm_ffi.Object):
    # tvm-ffi-stubgen(begin): object/ffi.reflection.AccessPath
    # fmt: off
    parent: Object | None
    step: AccessStep | None
    depth: int
    if TYPE_CHECKING:
        @staticmethod
        def _root() -> AccessPath: ...
        def _extend(self, _1: AccessStep, /) -> AccessPath: ...
        def _attr(self, _1: str, /) -> AccessPath: ...
        def _array_item(self, _1: int, /) -> AccessPath: ...
        def _map_item(self, _1: Any, /) -> AccessPath: ...
        def _attr_missing(self, _1: str, /) -> AccessPath: ...
        def _array_item_missing(self, _1: int, /) -> AccessPath: ...
        def _map_item_missing(self, _1: Any, /) -> AccessPath: ...
        def _is_prefix_of(self, _1: AccessPath, /) -> bool: ...
        def _to_steps(self, /) -> Sequence[AccessStep]: ...
        def _path_equal(self, _1: AccessPath, /) -> bool: ...
    # fmt: on
    # tvm-ffi-stubgen(end)
```
Import section
~~~~~~~~~~~~~~
```
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from ..registry import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence
    from tvm_ffi import Device, Object, Tensor, dtype
    from tvm_ffi.testing import TestIntPair
    from typing import Any, Callable
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
```
Auto-populates imports used by generated stubs.
Export
~~~~~~
```
# tvm-ffi-stubgen(begin): export/_ffi_api
# fmt: off
# isort: off
from ._ffi_api import * # noqa: F403
from ._ffi_api import __all__ as _ffi_api__all__
if "__all__" not in globals():
    __all__ = []
__all__.extend(_ffi_api__all__)
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
```
Re-exports a generated submodule's __all__ into the parent.
__all__
~~~~~~~
```
__all__ = [
    # tvm-ffi-stubgen(begin): __all__
    "LIB",
    "IntPair",
    "raise_error",
    # tvm-ffi-stubgen(end)
]
```
Populates __all__ with generated classes/functions and LIB (if present).
Type map
~~~~~~~~
```
# tvm-ffi-stubgen(ty-map): <type_key> -> <python_type>
```
Maps runtime type keys to Python types used in generation.
Example:
```
# tvm-ffi-stubgen(ty-map): ffi.reflection.AccessStep -> ffi.access_path.AccessStep
```
Import object
~~~~~~~~~~~~~
```
# tvm-ffi-stubgen(import-object): <from>; <type_checking_only>; <alias>
```
Injects a custom import into generated code, optionally TYPE_CHECKING-only.
Example:
```
# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
```
Skip file
~~~~~~~~~
```
# tvm-ffi-stubgen(skip-file)
```
Prevents stubgen from modifying the file.
```
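The `ty-map` and `import-object` directives are single-line comments with small payloads. The sketch below parses them the way the `cli.py` changes in this commit appear to: `ty-map` splits on `->` and strips both sides, and for `import-object` only a case-insensitive `true` makes the import TYPE_CHECKING-only, while an empty alias means no alias. The function names and the `ParsedImport` dataclass are hypothetical, and splitting the `import-object` payload on `;` is an assumption based on the documented `<from>; <type_checking_only>; <alias>` syntax.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical parsers for illustration; tvm-ffi-stubgen's own parsing may differ.
TY_MAP = "# tvm-ffi-stubgen(ty-map):"
IMPORT_OBJECT = "# tvm-ffi-stubgen(import-object):"


def parse_ty_map(line: str) -> tuple[str, str]:
    """Parse `# tvm-ffi-stubgen(ty-map): A.B -> C.D` into ("A.B", "C.D")."""
    line = line.strip()
    if not line.startswith(TY_MAP):
        raise ValueError(f"not a ty-map directive: {line!r}")
    try:
        lhs, rhs = line[len(TY_MAP):].split("->")
    except ValueError as e:
        raise ValueError("Invalid ty_map format. Example: `A.B -> C.D`") from e
    return lhs.strip(), rhs.strip()


@dataclass(frozen=True)
class ParsedImport:
    name: str
    type_checking_only: bool
    alias: str | None


def parse_import_object(line: str) -> ParsedImport:
    """Parse `# tvm-ffi-stubgen(import-object): <from>; <type_checking_only>; <alias>`."""
    line = line.strip()
    if not line.startswith(IMPORT_OBJECT):
        raise ValueError(f"not an import-object directive: {line!r}")
    name, flag, alias = (part.strip() for part in line[len(IMPORT_OBJECT):].split(";"))
    return ParsedImport(
        name=name,
        # Only a case-insensitive "true" makes the import TYPE_CHECKING-only.
        type_checking_only=flag.lower() == "true",
        alias=alias or None,
    )


print(parse_ty_map("# tvm-ffi-stubgen(ty-map): ffi.reflection.AccessStep -> ffi.access_path.AccessStep"))
# ('ffi.reflection.AccessStep', 'ffi.access_path.AccessStep')
print(parse_import_object("# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object"))
# ParsedImport(name='ffi.Object', type_checking_only=False, alias='_ffi_Object')
```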
---
.../packaging/python/my_ffi_extension/__init__.py | 65 +--
.../packaging/python/my_ffi_extension/_ffi_api.py | 57 ++-
examples/packaging/run_example.py | 2 +-
python/tvm_ffi/_ffi_api.py | 13 +-
python/tvm_ffi/_tensor.py | 5 +
python/tvm_ffi/access_path.py | 5 +-
python/tvm_ffi/container.py | 10 +
python/tvm_ffi/module.py | 5 +-
python/tvm_ffi/stub/analysis.py | 41 --
python/tvm_ffi/stub/cli.py | 514 +++++++++++++++++----
python/tvm_ffi/stub/codegen.py | 248 ++++++++--
python/tvm_ffi/stub/consts.py | 74 ++-
python/tvm_ffi/stub/file_utils.py | 95 +++-
python/tvm_ffi/stub/lib_state.py | 120 +++++
python/tvm_ffi/stub/utils.py | 114 ++++-
.../base.py => python/tvm_ffi/testing/__init__.py | 22 +-
python/tvm_ffi/testing/_ffi_api.py | 132 ++++++
python/tvm_ffi/{ => testing}/testing.py | 13 +-
tests/python/test_stubgen.py | 315 +++++++++++--
19 files changed, 1494 insertions(+), 356 deletions(-)
diff --git a/examples/packaging/python/my_ffi_extension/__init__.py b/examples/packaging/python/my_ffi_extension/__init__.py
index 292a03b..f114e8b 100644
--- a/examples/packaging/python/my_ffi_extension/__init__.py
+++ b/examples/packaging/python/my_ffi_extension/__init__.py
@@ -13,55 +13,16 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations.
-# order matters here so we need to skip isort here
-# isort: skip_file
-"""Public Python API for the example tvm-ffi extension package."""
-
-from typing import Any, TYPE_CHECKING
-
-import tvm_ffi
-
-from .base import _LIB
-from . import _ffi_api
-
-
-@tvm_ffi.register_object("my_ffi_extension.IntPair")
-class IntPair(tvm_ffi.Object):
- """IntPair object."""
-
- def __init__(self, a: int, b: int) -> None:
- """Construct the object."""
- # __ffi_init__ call into the refl::init<> registered
- # in the static initialization block of the extension library
- self.__ffi_init__(a, b)
-
-
-def add_one(x: Any, y: Any) -> None:
- """Add one to the input tensor.
-
- Parameters
- ----------
- x : Tensor
- The input tensor.
- y : Tensor
- The output tensor.
-
- """
- return _LIB.add_one(x, y)
-
-
-def raise_error(msg: str) -> None:
- """Raise an error with the given message.
-
- Parameters
- ----------
- msg : str
- The message to raise the error with.
-
- Raises
- ------
- RuntimeError
- The error raised by the function.
-
- """
- return _ffi_api.raise_error(msg)
+"""Package my_ffi_extension."""
+
+# tvm-ffi-stubgen(begin): export/_ffi_api
+# fmt: off
+# isort: off
+from ._ffi_api import * # noqa: F403
+from ._ffi_api import __all__ as _ffi_api__all__
+if "__all__" not in globals():
+ __all__ = []
+__all__.extend(_ffi_api__all__)
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
diff --git a/examples/packaging/python/my_ffi_extension/_ffi_api.py b/examples/packaging/python/my_ffi_extension/_ffi_api.py
index f57eb2a..887b991 100644
--- a/examples/packaging/python/my_ffi_extension/_ffi_api.py
+++ b/examples/packaging/python/my_ffi_extension/_ffi_api.py
@@ -13,22 +13,53 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations.
+"""FFI API bindings for my_ffi_extension."""
+# tvm-ffi-stubgen(begin): import-section
+# fmt: off
+# isort: off
+from __future__ import annotations
+from tvm_ffi import Object as _ffi_Object, init_ffi_api as _FFI_INIT_FUNC, register_object as _FFI_REG_OBJ
+from tvm_ffi.libinfo import load_lib_module as _FFI_LOAD_LIB
from typing import TYPE_CHECKING
-
-import tvm_ffi
-
-# make sure lib is loaded first
-from .base import _LIB # noqa: F401
-
-# this is a short cut to register all the global functions
-# prefixed by `my_ffi_extension.` to this module
-tvm_ffi.init_ffi_api("my_ffi_extension", __name__)
-
-
+if TYPE_CHECKING:
+ from tvm_ffi import Object
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
+# tvm-ffi-stubgen(import-object): tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB
+LIB = _FFI_LOAD_LIB("my-ffi-extension", "my_ffi_extension")
# tvm-ffi-stubgen(begin): global/my_ffi_extension
+# fmt: off
+_FFI_INIT_FUNC("my_ffi_extension", __name__)
if TYPE_CHECKING:
- # fmt: off
def raise_error(_0: str, /) -> None: ...
- # fmt: on
+# fmt: on
# tvm-ffi-stubgen(end)
+# tvm-ffi-stubgen(import-object): tvm_ffi.register_object;False;_FFI_REG_OBJ
+# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object
+@_FFI_REG_OBJ("my_ffi_extension.IntPair")
+class IntPair(_ffi_Object):
+ """FFI binding for `my_ffi_extension.IntPair`."""
+
+ # tvm-ffi-stubgen(begin): object/my_ffi_extension.IntPair
+ # fmt: off
+ a: int
+ b: int
+ if TYPE_CHECKING:
+ @staticmethod
+ def __c_ffi_init__(_0: int, _1: int, /) -> Object: ...
+ @staticmethod
+ def static_get_second(_0: IntPair, /) -> int: ...
+ def get_first(self, /) -> int: ...
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+
+__all__ = [
+ # tvm-ffi-stubgen(begin): __all__
+ "LIB",
+ "IntPair",
+ "raise_error",
+ # tvm-ffi-stubgen(end)
+]
diff --git a/examples/packaging/run_example.py b/examples/packaging/run_example.py
index f2c79f4..3f3e636 100644
--- a/examples/packaging/run_example.py
+++ b/examples/packaging/run_example.py
@@ -26,7 +26,7 @@ def run_add_one() -> None:
"""Invoke add_one from the extension and print the result."""
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
y = torch.empty_like(x)
- my_ffi_extension.add_one(x, y)
+ my_ffi_extension.LIB.add_one(x, y)
print(y)
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index f9850a5..49fbf68 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -16,23 +16,24 @@
# under the License.
"""FFI API."""
-# tvm-ffi-stubgen(begin): import
+# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
-from typing import Any, Callable, TYPE_CHECKING
+from .registry import init_ffi_api as _FFI_INIT_FUNC
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from tvm_ffi import Module, Object
from tvm_ffi.access_path import AccessPath
+ from typing import Any, Callable
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
-from . import registry
-
-# tvm-ffi-stubgen(begin): global/ffi
+# tvm-ffi-stubgen(begin): global/[email protected]
# fmt: off
+_FFI_INIT_FUNC("ffi", __name__)
if TYPE_CHECKING:
def Array(*args: Any) -> Any: ...
def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
@@ -76,8 +77,6 @@ if TYPE_CHECKING:
# fmt: on
# tvm-ffi-stubgen(end)
-registry.init_ffi_api("ffi", __name__)
-
__all__ = [
# tvm-ffi-stubgen(begin): __all__
diff --git a/python/tvm_ffi/_tensor.py b/python/tvm_ffi/_tensor.py
index bb60882..b9a4954 100644
--- a/python/tvm_ffi/_tensor.py
+++ b/python/tvm_ffi/_tensor.py
@@ -57,6 +57,11 @@ class Shape(tuple, PyNativeObject):
_tvm_ffi_cached_object: Any
+ # tvm-ffi-stubgen(begin): object/ffi.Shape
+ # fmt: off
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
def __new__(cls, content: tuple[int, ...]) -> Shape:
if any(not isinstance(x, Integral) for x in content):
raise ValueError("Shape must be a tuple of integers")
diff --git a/python/tvm_ffi/access_path.py b/python/tvm_ffi/access_path.py
index 89d1b3c..6980e49 100644
--- a/python/tvm_ffi/access_path.py
+++ b/python/tvm_ffi/access_path.py
@@ -17,14 +17,15 @@
# pylint: disable=invalid-name
"""Access path classes."""
-# tvm-ffi-stubgen(begin): import
+# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
from tvm_ffi import Object
+ from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index dfa0d22..bda4c57 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -150,6 +150,11 @@ class Array(core.Object, Sequence[T]):
"""
+ # tvm-ffi-stubgen(begin): object/ffi.Array
+ # fmt: off
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
def __init__(self, input_list: Iterable[T]) -> None:
"""Construct an Array from a Python sequence."""
self.__init_handle_by_constructor__(_ffi_api.Array, *input_list)
@@ -291,6 +296,11 @@ class Map(core.Object, Mapping[K, V]):
"""
+ # tvm-ffi-stubgen(begin): object/ffi.Map
+ # fmt: off
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
def __init__(self, input_dict: Mapping[K, V]) -> None:
"""Construct a Map from a Python mapping."""
list_kvs: list[Any] = []
diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py
index d863ae9..5aac8f0 100644
--- a/python/tvm_ffi/module.py
+++ b/python/tvm_ffi/module.py
@@ -16,13 +16,14 @@
# under the License.
"""Module related objects and functions."""
-# tvm-ffi-stubgen(begin): import
+# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
+ from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
diff --git a/python/tvm_ffi/stub/analysis.py b/python/tvm_ffi/stub/analysis.py
deleted file mode 100644
index 03dbe36..0000000
--- a/python/tvm_ffi/stub/analysis.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Analysis utilities."""
-
-from __future__ import annotations
-
-from tvm_ffi.registry import list_global_func_names
-
-from . import consts as C
-from .utils import FuncInfo
-
-
-def collect_global_funcs() -> dict[str, list[FuncInfo]]:
- """Collect global functions from TVM FFI's global registry."""
- # Build global function table only if we are going to process blocks.
- global_funcs: dict[str, list[FuncInfo]] = {}
- for name in list_global_func_names():
- try:
- prefix, _ = name.rsplit(".", 1)
- except ValueError:
- print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}")
- else:
- global_funcs.setdefault(prefix, []).append(FuncInfo.from_global_name(name))
- # Ensure stable ordering for deterministic output.
- for k in list(global_funcs.keys()):
- global_funcs[k].sort(key=lambda x: x.schema.name)
- return global_funcs
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
index 85697f5..9b9786c 100644
--- a/python/tvm_ffi/stub/cli.py
+++ b/python/tvm_ffi/stub/cli.py
@@ -14,37 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# tvm-ffi-stubgen(skip-file)
"""TVM-FFI Stub Generator (``tvm-ffi-stubgen``)."""
from __future__ import annotations
import argparse
import ctypes
+import importlib
import sys
import traceback
from pathlib import Path
-from typing import Callable
from . import codegen as G
from . import consts as C
-from .analysis import collect_global_funcs
from .file_utils import FileInfo, collect_files
-from .utils import Options
-
-
-def _fn_ty_map(ty_map: dict[str, str], ty_used: set[str]) -> Callable[[str], str]:
- def _run(name: str) -> str:
- nonlocal ty_map, ty_used
- if (ret := ty_map.get(name)) is not None:
- name = ret
- if (ret := C.TY_TO_IMPORT.get(name)) is not None:
- name = ret
- if "." in name:
- ty_used.add(name)
- return name.rsplit(".", 1)[-1]
-
- return _run
+from .lib_state import (
+ collect_global_funcs,
+ collect_type_keys,
+ object_info_from_type_key,
+ toposort_objects,
+)
+from .utils import FuncInfo, ImportItem, InitConfig, Options
def __main__() -> int:
@@ -55,72 +45,50 @@ def __main__() -> int:
overview and examples of the block syntax.
"""
opt = _parse_args()
+ for imp in opt.imports or []:
+ importlib.import_module(imp)
dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
files: list[FileInfo] = collect_files([Path(f) for f in opt.files])
+ global_funcs: dict[str, list[FuncInfo]] = collect_global_funcs()
+ init_path: Path | None = None
+ if opt.files:
+ init_path = Path(opt.files[0]).resolve()
+ if init_path.is_file():
+ init_path = init_path.parent
- # Stage 1: Process `tvm-ffi-stubgen(ty-map)`
+ # Stage 1: Collect information
+ # - type maps: `tvm-ffi-stubgen(ty-map)`
+ # - defined global functions: `tvm-ffi-stubgen(begin): global/...`
+ # - defined object types: `tvm-ffi-stubgen(begin): object/...`
ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
-
- def _stage_1(file: FileInfo) -> None:
- for code in file.code_blocks:
- if code.kind == "ty-map":
- try:
- lhs, rhs = code.param.split("->")
- except ValueError as e:
- raise ValueError(
- f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`"
- ) from e
- ty_map[lhs.strip()] = rhs.strip()
-
for file in files:
try:
- _stage_1(file)
+ _stage_1(file, ty_map)
except Exception:
print(
f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}'
)
- # Stage 2: Process
+ # Stage 2. Generate stubs if they are not defined on the file.
+ if opt.init:
+ assert init_path is not None, "init-path could not be determined"
+ _stage_2(
+ files,
+ ty_map,
+ init_cfg=opt.init,
+ init_path=init_path,
+ global_funcs=global_funcs,
+ )
+
+ # Stage 3: Process
# - `tvm-ffi-stubgen(begin): global/...`
# - `tvm-ffi-stubgen(begin): object/...`
- global_funcs = collect_global_funcs()
-
- def _stage_2(file: FileInfo) -> None:
- all_defined = set()
+ for file in files:
if opt.verbose:
print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
- ty_used: set[str] = set()
- ty_on_file: set[str] = set()
- fn_ty_map_fn = _fn_ty_map(ty_map, ty_used)
- # Stage 2.1. Process `tvm-ffi-stubgen(begin): global/...`
- for code in file.code_blocks:
- if code.kind == "global":
- funcs = global_funcs.get(code.param, [])
- for func in funcs:
- all_defined.add(func.schema.name)
- G.generate_global_funcs(code, funcs, fn_ty_map_fn, opt)
- # Stage 2.2. Process `tvm-ffi-stubgen(begin): object/...`
- for code in file.code_blocks:
- if code.kind == "object":
- type_key = code.param
- ty_on_file.add(ty_map.get(type_key, type_key))
- G.generate_object(code, fn_ty_map_fn, opt)
- # Stage 2.3. Add imports for used types.
- for code in file.code_blocks:
- if code.kind == "import":
- G.generate_imports(code, ty_used - ty_on_file, opt)
- break # Only one import block per file is supported for now.
- # Stage 2.4. Add `__all__` for defined classes and functions.
- for code in file.code_blocks:
- if code.kind == "__all__":
- G.generate_all(code, all_defined | ty_on_file, opt)
- break # Only one __all__ block per file is supported for now.
- file.update(show_diff=opt.verbose, dry_run=opt.dry_run)
-
- for file in files:
try:
- _stage_2(file)
- except:
+ _stage_3(file, opt, ty_map, global_funcs)
+ except Exception:
print(
f'{C.TERM_RED}[Failed] File "{file.path}": {traceback.format_exc()}{C.TERM_RESET}'
)
@@ -128,56 +96,379 @@ def __main__() -> int:
return 0
+def _stage_1(
+ file: FileInfo,
+ ty_map: dict[str, str],
+) -> None:
+ for code in file.code_blocks:
+ if code.kind == "ty-map":
+ try:
+ assert isinstance(code.param, str)
+ lhs, rhs = code.param.split("->")
+ except ValueError as e:
+ raise ValueError(
+ f"Invalid ty_map format at line {code.lineno_start}. Example: `A.B -> C.D`"
+ ) from e
+ ty_map[lhs.strip()] = rhs.strip()
+
+
+def _stage_2(
+ files: list[FileInfo],
+ ty_map: dict[str, str],
+ init_cfg: InitConfig,
+ init_path: Path,
+ global_funcs: dict[str, list[FuncInfo]],
+) -> None:
+ def _find_or_insert_file(path: Path) -> FileInfo:
+ ret: FileInfo | None
+ if not path.exists():
+ ret = FileInfo(path=path, lines=(), code_blocks=[])
+ else:
+ for file in files:
+ if path.samefile(file.path):
+ return file
+ ret = FileInfo.from_file(file=path, include_empty=True)
+ assert ret is not None, f"Failed to read file: {path}"
+ files.append(ret)
+ return ret
+
+ # Step 0. Find out functions and classes already defined on files.
+ defined_func_prefixes: set[str] = { # type: ignore[union-attr]
+ code.param[0] for file in files for code in file.code_blocks if code.kind == "global"
+ }
+ defined_objs: set[str] = { # type: ignore[assignment]
+ code.param for file in files for code in file.code_blocks if code.kind == "object"
+ } | C.BUILTIN_TYPE_KEYS
+
+ # Step 0. Generate missing `_ffi_api.py` and `__init__.py` under each prefix.
+ prefixes: dict[str, list[str]] = collect_type_keys()
+ for prefix in global_funcs:
+ prefixes.setdefault(prefix, [])
+
+ root_ffi_api_py = init_path / init_cfg.prefix.rstrip(".") / "_ffi_api.py"
+ for prefix, obj_names in prefixes.items():
+ # TODO(@junrushao): control the prefix to generate stubs for
+ if prefix.startswith("testing") or prefix.startswith("ffi"):
+ continue
+ funcs = sorted(
+ [] if prefix in defined_func_prefixes else global_funcs.get(prefix, []),
+ key=lambda f: f.schema.name,
+ )
+ objs = sorted(set(obj_names) - defined_objs)
+ object_infos = toposort_objects(objs)
+ if not funcs and not object_infos:
+ continue
+ # Step 1. Create target directory if not exists
+ directory = init_path / prefix.replace(".", "/")
+ directory.mkdir(parents=True, exist_ok=True)
+ # Step 2. Generate `_ffi_api.py`
+ target_path = directory / "_ffi_api.py"
+ target_file = _find_or_insert_file(target_path)
+ with target_path.open("a", encoding="utf-8") as f:
+ f.write(
+ G.generate_ffi_api(
+ target_file.code_blocks,
+ ty_map,
+ prefix,
+ object_infos,
+ init_cfg,
+ is_root=root_ffi_api_py.samefile(target_path),
+ )
+ )
+ target_file.reload()
+ # Step 3. Generate `__init__.py`
+ target_path = directory / "__init__.py"
+ target_file = _find_or_insert_file(target_path)
+ with target_path.open("a", encoding="utf-8") as f:
+ f.write(G.generate_init(target_file.code_blocks, prefix, submodule="_ffi_api"))
+ target_file.reload()
+
+
+def _stage_3( # noqa: PLR0912
+ file: FileInfo,
+ opt: Options,
+ ty_map: dict[str, str],
+ global_funcs: dict[str, list[FuncInfo]],
+) -> None:
+ defined_funcs: set[str] = set()
+ defined_types: set[str] = set()
+ imports: list[ImportItem] = []
+ ffi_load_lib_imported = False
+ # Stage 1. Collect `tvm-ffi-stubgen(import-object): ...`
+ for code in file.code_blocks:
+ if code.kind == "import-object":
+ name, type_checking_only, alias = code.param # type: ignore[misc]
+ imports.append(
+ ImportItem(
+ name,
+ type_checking_only=(
+ bool(type_checking_only)
+ and isinstance(type_checking_only, str)
+ and type_checking_only.lower() == "true"
+ ),
+ alias=alias if alias else None,
+ )
+ )
+ if (alias and alias == "_FFI_LOAD_LIB") or name.endswith("libinfo.load_lib_module"):
+ ffi_load_lib_imported = True
+ # Stage 2. Process `tvm-ffi-stubgen(begin): global/...`
+ for code in file.code_blocks:
+ if code.kind == "global":
+ funcs = global_funcs.get(code.param[0], [])
+ for func in funcs:
+ defined_funcs.add(func.schema.name)
+ G.generate_global_funcs(code, funcs, ty_map, imports, opt)
+ # Stage 3. Process `tvm-ffi-stubgen(begin): object/...`
+ for code in file.code_blocks:
+ if code.kind == "object":
+ type_key = code.param
+ assert isinstance(type_key, str)
+ obj_info = object_info_from_type_key(type_key)
+ type_key = ty_map.get(type_key, type_key)
+ full_name = ImportItem(type_key).full_name
+ defined_types.add(full_name)
+ G.generate_object(code, ty_map, imports, opt, obj_info)
+ # Stage 4. Add imports for used types.
+ imports = [i for i in imports if i.full_name not in defined_types]
+ for code in file.code_blocks:
+ if code.kind == "import-section":
+ G.generate_import_section(code, imports, opt)
+ break # Only one import block per file is supported for now.
+ # Stage 5. Add `__all__` for defined classes and functions.
+ for code in file.code_blocks:
+ if code.kind == "__all__":
+ export_names = defined_funcs | defined_types
+ if ffi_load_lib_imported:
+ export_names = export_names | {"LIB"}
+ G.generate_all(code, export_names, opt)
+ break # Only one __all__ block per file is supported for now.
+ # Stage 6. Process `tvm-ffi-stubgen(begin): export/...`
+ for code in file.code_blocks:
+ if code.kind == "export":
+ G.generate_export(code)
+ # Finalize: write back to file
+ file.update(verbose=opt.verbose, dry_run=opt.dry_run)
+
+
def _parse_args() -> Options:
class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
pass
+ def _split_list_arg(arg: str | None) -> list[str]:
+ if not arg:
+ return []
+ return [item.strip() for item in arg.split(";") if item.strip()]
+
parser = argparse.ArgumentParser(
prog="tvm-ffi-stubgen",
description=(
- "Generate in-place type stubs for TVM FFI.\n\n"
- "It scans .py/.pyi files for tvm-ffi-stubgen blocks and fills them with\n"
- "TYPE_CHECKING-only annotations derived from TVM runtime metadata."
+ "Generate type stubs for TVM FFI extensions.\n\n"
+ "In `--init-*` mode, it generates missing `_ffi_api.py` and `__init__.py` files, "
+ "based on the registered global functions and object types in the loaded libraries.\n\n"
+ "In normal mode, it processes the given files/directories in-place, generating "
+ "type stubs inside special `tvm-ffi-stubgen` blocks. Scroll down for more details."
),
formatter_class=HelpFormatter,
epilog=(
- "Examples:\n"
+ "========\n"
+ "Examples\n"
+ "========\n\n"
" # Single file\n"
" tvm-ffi-stubgen python/tvm_ffi/_ffi_api.py\n\n"
" # Recursively scan directories\n"
" tvm-ffi-stubgen python/tvm_ffi examples/packaging/python/my_ffi_extension\n\n"
- " # Preload TVM runtime / extension libraries\n"
- " tvm-ffi-stubgen --dlls build/libtvm_runtime.so build/libmy_ext.so my_pkg/_ffi_api.py\n\n"
- "Stub block syntax (placed in your source):\n"
- " # tvm-ffi-stubgen(begin): global/<registry-prefix>\n"
- " ... generated function stubs ...\n"
- " # tvm-ffi-stubgen(end)\n\n"
- " # tvm-ffi-stubgen(begin): object/<type_key>\n"
- " # tvm-ffi-stubgen(ty_map): list -> Sequence\n"
- " # tvm-ffi-stubgen(ty_map): dict -> Mapping\n"
- " ... generated fields and methods ...\n"
- " # tvm-ffi-stubgen(end)\n\n"
- " # Skip a file entirely\n"
- " # tvm-ffi-stubgen(skip-file)\n\n"
- "Tips:\n"
- " - Only .py/.pyi files are updated; directories are scanned recursively.\n"
- " - Import any aliases you use in ty_map under TYPE_CHECKING, e.g.\n"
- " from collections.abc import Mapping, Sequence\n"
- " - Use --dlls to preload shared libraries when function/type metadata\n"
- " is provided by native extensions.\n"
+ " # Preload extension libraries\n"
+ ' tvm-ffi-stubgen --dlls "build/libmy_ext.so;build/libmy_2nd_ext.so" my_pkg/_ffi_api.py\n\n'
+ " # Package-level init (my-ffi-extension)\n"
+ " tvm-ffi-stubgen examples/packaging/python \\\n"
+ " --dlls examples/packaging/build/libmy_ffi_extension.dylib \\\n"
+ " --init-pypkg my-ffi-extension \\\n"
+ " --init-lib my_ffi_extension \\\n"
+ ' --init-prefix "my_ffi_extension."\n\n'
+ "=====================\n"
+ "Syntax of stub blocks\n"
+ "=====================\n\n"
+ "Global functions\n"
+ "~~~~~~~~~~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_BEGIN} global/<registry-prefix>@<import-from (default: tvm_ffi)>\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Generates TYPE_CHECKING-only stubs for functions in the global registry under the prefix.\n\n"
+ "Example:\n\n"
+ " ```\n"
+ f" {C.STUB_BEGIN} global/[email protected]\n"
+ " # fmt: off\n"
+ ' _FFI_INIT_FUNC("ffi", __name__)\n'
+ " if TYPE_CHECKING:\n"
+ " def Array(*args: Any) -> Any: ...\n"
+ " def ArrayGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...\n"
+ " def ArraySize(_0: Sequence[Any], /) -> int: ...\n"
+ " def Bytes(_0: bytes, /) -> bytes: ...\n"
+ " ...\n"
+ " def StructuralHash(_0: Any, _1: bool, _2: bool, /) -> int: ...\n"
+ " def SystemLib(*args: Any) -> Any: ...\n"
+ " def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...\n"
+ " def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...\n"
+ " # fmt: on\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Objects\n"
+ "~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_BEGIN} object/<type_key>\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Generates fields/methods for a class defined using TVM-FFI Object APIs.\n\n"
+ "Example:\n\n"
+ " ```\n"
+ ' @register_object("ffi.reflection.AccessPath")\n'
+ " class AccessPath(tvm_ffi.Object):\n"
+ f" {C.STUB_BEGIN} object/ffi.reflection.AccessPath\n"
+ " # fmt: off\n"
+ " parent: Object | None\n"
+ " step: AccessStep | None\n"
+ " depth: int\n"
+ " if TYPE_CHECKING:\n"
+ " @staticmethod\n"
+ " def _root() -> AccessPath: ...\n"
+ " def _extend(self, _1: AccessStep, /) -> AccessPath: ...\n"
+ " def _attr(self, _1: str, /) -> AccessPath: ...\n"
+ " def _array_item(self, _1: int, /) -> AccessPath: ...\n"
+ " def _map_item(self, _1: Any, /) -> AccessPath: ...\n"
+ " def _attr_missing(self, _1: str, /) -> AccessPath: ...\n"
+ " def _array_item_missing(self, _1: int, /) -> AccessPath: ...\n"
+ " def _map_item_missing(self, _1: Any, /) -> AccessPath: ...\n"
+ " def _is_prefix_of(self, _1: AccessPath, /) -> bool: ...\n"
+ " def _to_steps(self, /) -> Sequence[AccessStep]: ...\n"
+ " def _path_equal(self, _1: AccessPath, /) -> bool: ...\n"
+ " # fmt: on\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Import section\n"
+ "~~~~~~~~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_BEGIN} import-section\n"
+ " # fmt: off\n"
+ " # isort: off\n"
+ " from __future__ import annotations\n"
+ " from ..registry import init_ffi_api as _FFI_INIT_FUNC\n"
+ " from typing import TYPE_CHECKING\n"
+ " if TYPE_CHECKING:\n"
+ " from collections.abc import Mapping, Sequence\n"
+ " from tvm_ffi import Device, Object, Tensor, dtype\n"
+ " from tvm_ffi.testing import TestIntPair\n"
+ " from typing import Any, Callable\n"
+ " # isort: on\n"
+ " # fmt: on\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Auto-populates imports used by generated stubs.\n\n"
+ "Export\n"
+ "~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_BEGIN} export/_ffi_api\n"
+ " # fmt: off\n"
+ " # isort: off\n"
+ " from ._ffi_api import * # noqa: F403\n"
+ " from ._ffi_api import __all__ as _ffi_api__all__\n"
+ ' if "__all__" not in globals():\n'
+ " __all__ = []\n"
+ " __all__.extend(_ffi_api__all__)\n"
+ " # isort: on\n"
+ " # fmt: on\n"
+ f" {C.STUB_END}\n"
+ " ```\n\n"
+ "Re-exports a generated submodule's __all__ into the parent.\n\n"
+ "__all__\n"
+ "~~~~~~~\n\n"
+ " ```\n"
+ " __all__ = [\n"
+ f" {C.STUB_BEGIN} __all__\n"
+ ' "LIB",\n'
+ ' "IntPair",\n'
+ ' "raise_error",\n'
+ f" {C.STUB_END}\n"
+ " ]\n"
+ " ```\n\n"
+ "Populates __all__ with generated classes/functions and LIB (if present).\n\n"
+ "Type map\n"
+ "~~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_TY_MAP} <type_key> -> <python_type>\n"
+ " ```\n\n"
+ "Maps runtime type keys to Python types used in generation.\n\n"
+ "Example:\n\n"
+ " ```\n"
+ f" {C.STUB_TY_MAP} ffi.reflection.AccessStep -> ffi.access_path.AccessStep\n"
+ " ```\n\n"
+ "Import object\n"
+ "~~~~~~~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_IMPORT_OBJECT} <from>; <type_checking_only>; <alias>\n"
+ " ```\n\n"
+ "Injects a custom import into generated code, optionally TYPE_CHECKING-only.\n\n"
+ "Example:\n\n"
+ " ```\n"
+ f" {C.STUB_IMPORT_OBJECT} ffi.Object;False;_ffi_Object\n"
+ " ```\n\n"
+ "Skip file\n"
+ "~~~~~~~~~\n\n"
+ " ```\n"
+ f" {C.STUB_SKIP_FILE}\n"
+ " ```\n\n"
+ "Prevents stubgen from modifying the file."
+ ),
+ )
+ parser.add_argument(
+ "--imports",
+ type=str,
+ default="",
+ metavar="IMPORTS",
+ help=(
+ "Additional imports to load before generation, separated by ';' "
+ "(e.g. 'pkgA;pkgB.submodule')."
),
)
parser.add_argument(
"--dlls",
- nargs="*",
- metavar="LIB",
+ type=str,
+ default="",
+ metavar="LIBS",
help=(
"Shared libraries to preload before generation (e.g. TVM runtime or "
- "your extension). This ensures global function and object metadata "
- "is available. Accepts multiple paths; platform-specific suffixes "
- "like .so/.dylib/.dll are supported."
+ "your extension), separated by ';'. This ensures global function and "
+ "object metadata is available. Platform-specific suffixes like "
+ ".so/.dylib/.dll are supported."
+ ),
+ )
+ parser.add_argument(
+ "--init-pypkg",
+ type=str,
+ default="",
+ help=(
+ "Python package name to generate stubs for (e.g. apache-tvm-ffi). "
+ "Required together with --init-lib and --init-prefix."
+ ),
+ )
+ parser.add_argument(
+ "--init-lib",
+ type=str,
+ default="",
+ help=(
+ "CMake target that produces the shared library to load for stub generation "
+ "(e.g. tvm_ffi_shared). Required together with --init-pypkg and "
+ "--init-prefix."
+ ),
+ )
+ parser.add_argument(
+ "--init-prefix",
+ type=str,
+ default="",
+ help=(
+ "Global function/object prefix to include when generating stubs "
+ "(e.g. tvm_ffi.). Required together with --init-pypkg and --init-lib."
),
- default=[],
)
parser.add_argument(
"--indent",
@@ -185,7 +476,7 @@ def _parse_args() -> Options:
default=4,
help=(
"Extra spaces added inside each generated block, relative to the "
- "indentation of the corresponding '# tvm-ffi-stubgen(begin):' line."
+ f"indentation of the corresponding '{C.STUB_BEGIN}' line."
),
)
parser.add_argument(
@@ -214,11 +505,32 @@ def _parse_args() -> Options:
"without modifying any files."
),
)
- opt = Options(**vars(parser.parse_args()))
- if not opt.files:
+ args = parser.parse_args()
+
+ init_flags = [args.init_pypkg, args.init_lib, args.init_prefix]
+ init_cfg: InitConfig | None = None
+ if any(init_flags):
+ if not all(init_flags):
+ parser.error("--init-pypkg, --init-lib, and --init-prefix must be provided together")
+ init_cfg = InitConfig(
+ pkg=args.init_pypkg,
+ shared_target=args.init_lib,
+ prefix=args.init_prefix,
+ )
+
+ if not args.files:
parser.print_help()
sys.exit(1)
- return opt
+
+ return Options(
+ imports=_split_list_arg(args.imports),
+ dlls=_split_list_arg(args.dlls),
+ init=init_cfg,
+ indent=args.indent,
+ files=args.files,
+ verbose=args.verbose,
+ dry_run=args.dry_run,
+ )
if __name__ == "__main__":
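`_parse_args` now reads `--imports` and `--dlls` as single `;`-separated strings, and it requires the three `--init-*` flags to appear together. The helper `_split_list_arg` is called here, but its definition is not part of this diff. The sketch below therefore assumes it splits on `;`, strips whitespace, and drops empty entries. `split_list_arg` and `check_init_flags` are hypothetical names used only for illustration.

```python
from __future__ import annotations


def split_list_arg(value: str) -> list[str]:
    """Hypothetical stand-in for `_split_list_arg`: split on ';' and drop empty parts."""
    return [part.strip() for part in value.split(";") if part.strip()]


def check_init_flags(pypkg: str, lib: str, prefix: str) -> bool:
    """Return True if init mode is requested; raise if only some flags are set.

    Mirrors the all-or-none rule in `_parse_args`: either none of
    --init-pypkg/--init-lib/--init-prefix is given, or all three are.
    """
    flags = [pypkg, lib, prefix]
    if not any(flags):
        return False
    if not all(flags):
        raise ValueError("--init-pypkg, --init-lib, and --init-prefix must be provided together")
    return True


if __name__ == "__main__":
    print(split_list_arg("build/libmy_ext.so;build/libmy_2nd_ext.so"))
    print(check_init_flags("my-ffi-extension", "my_ffi_extension", "my_ffi_extension."))
```

POSIX shells also treat `;` as a command separator, so a multi-library `--dlls` value must be quoted on the command line.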
diff --git a/python/tvm_ffi/stub/codegen.py b/python/tvm_ffi/stub/codegen.py
index c15624a..45e6231 100644
--- a/python/tvm_ffi/stub/codegen.py
+++ b/python/tvm_ffi/stub/codegen.py
@@ -22,26 +22,59 @@ from typing import Callable
from . import consts as C
from .file_utils import CodeBlock
-from .utils import FuncInfo, ObjectInfo, Options
+from .utils import FuncInfo, ImportItem, InitConfig, ObjectInfo, Options
+
+
+def _type_suffix_and_record(
+ ty_map: dict[str, str], imports: list[ImportItem]
+) -> Callable[[str], str]:
+ def _run(name: str) -> str:
+ nonlocal ty_map, imports
+ name = ty_map.get(name, name)
+ if "." in name:
+ imports.append(ImportItem(name, type_checking_only=True, alias=None))
+ return name.rsplit(".", 1)[-1]
+
+ return _run
def generate_global_funcs(
- code: CodeBlock, global_funcs: list[FuncInfo], fn_ty_map: Callable[[str], str], opt: Options
+ code: CodeBlock,
+ global_funcs: list[FuncInfo],
+ ty_map: dict[str, str],
+ imports: list[ImportItem],
+ opt: Options,
) -> None:
- """Generate function signatures for global functions."""
+ """Generate function signatures for global functions.
+
+ It processes: global/${prefix}@${import_from="tvm_ffi"}
+ """
assert len(code.lines) >= 2
if not global_funcs:
return
+ assert isinstance(code.param, tuple)
+ prefix, import_from = code.param
+ if not import_from:
+ import_from = "tvm_ffi"
+ imports.extend(
+ [
+ ImportItem(
+ f"{import_from}.init_ffi_api",
+ type_checking_only=False,
+ alias="_FFI_INIT_FUNC",
+ ),
+ ImportItem(
+ "typing.TYPE_CHECKING",
+ type_checking_only=False,
+ ),
+ ]
+ )
+ fn_ty_map = _type_suffix_and_record(ty_map, imports)
results: list[str] = [
"# fmt: off",
+ f'_FFI_INIT_FUNC("{prefix}", __name__)',
"if TYPE_CHECKING:",
- *[
- func.gen(
- fn_ty_map,
- indent=opt.indent,
- )
- for func in global_funcs
- ],
+ *[func.gen(fn_ty_map, indent=opt.indent) for func in global_funcs],
"# fmt: on",
]
indent = " " * code.indent
@@ -52,11 +85,27 @@ def generate_global_funcs(
]
-def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt: Options) -> None:
- """Generate a class definition for an object type."""
+def generate_object(
+ code: CodeBlock,
+ ty_map: dict[str, str],
+ imports: list[ImportItem],
+ opt: Options,
+ obj_info: ObjectInfo,
+) -> None:
+ """Generate a class definition for an object type.
+
+ It processes: object/${type_key}
+ """
assert len(code.lines) >= 2
- info = ObjectInfo.from_type_key(code.param)
+ info = obj_info
+ fn_ty_map = _type_suffix_and_record(ty_map, imports)
if info.methods:
+ imports.append(
+ ImportItem(
+ "typing.TYPE_CHECKING",
+ type_checking_only=False,
+ )
+ )
results = [
"# fmt: off",
*info.gen_fields(fn_ty_map, indent=0),
@@ -78,43 +127,52 @@ def generate_object(code: CodeBlock, fn_ty_map: Callable[[str], str], opt: Optio
]
-def generate_imports(code: CodeBlock, ty_used: set[str], opt: Options) -> None:
- """Generate import statements for the types used in the stub."""
- ty_collected: dict[str, list[str]] = {}
- for ty in ty_used:
- assert "." in ty
- module, name = ty.rsplit(".", 1)
- for mod_prefix, mod_replacement in C.MOD_MAP.items():
- if module.startswith(mod_prefix):
- module = module.replace(mod_prefix, mod_replacement, 1)
- break
- ty_collected.setdefault(module, []).append(name)
- if not ty_collected:
- return
+def generate_import_section(
+ code: CodeBlock,
+ imports: list[ImportItem],
+ opt: Options,
+) -> None:
+ """Generate import statements for the types used in the stub.
+
+ It processes: import-section
+ """
+ imports_concrete: dict[str, list[ImportItem]] = {}
+ imports_ty_check: dict[str, list[ImportItem]] = {}
+ for item in imports:
+ if item.type_checking_only:
+ imports_ty_check.setdefault(item.mod, []).append(item)
+ else:
+ imports_concrete.setdefault(item.mod, []).append(item)
+ if imports_ty_check:
+ imports_concrete.setdefault("typing", []).append(
+ ImportItem("typing.TYPE_CHECKING", type_checking_only=True)
+ )
- def _make_line(module: str, names: list[str], indent: int) -> str:
- names = ", ".join(sorted(set(names)))
+ def _make_line(mod: str, items: list[ImportItem], indent: int) -> str:
+ items.sort(key=lambda item: item.name)
+ names = ", ".join(sorted(set(item.name_with_alias for item in items)))
indent_str = " " * indent
- return f"{indent_str}from {module} import {names}"
+ if mod:
+ return f"{indent_str}from {mod} import {names}"
+ else:
+ return f"{indent_str}import {names}"
- results: list[str] = [
- "from __future__ import annotations",
- _make_line(
- "typing",
- [*ty_collected.pop("typing", []), "TYPE_CHECKING"],
- indent=0,
- ),
- ]
- if ty_collected:
+ results: list[str] = []
+ if imports_concrete:
+ results.extend(
+ _make_line(mod, imports_concrete[mod], indent=0) for mod in sorted(imports_concrete)
+ )
+ if imports_ty_check:
results.append("if TYPE_CHECKING:")
- for module in sorted(ty_collected):
- names = ty_collected[module]
- results.append(_make_line(module, names, indent=opt.indent))
+ results.extend(
+ _make_line(mod, imports_ty_check[mod], opt.indent) for mod in sorted(imports_ty_check)
+ )
if results:
code.lines = [
code.lines[0],
"# fmt: off",
"# isort: off",
+ "from __future__ import annotations",
*results,
"# isort: on",
"# fmt: on",
@@ -130,8 +188,114 @@ def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None:
indent = " " * code.indent
names = {f.rsplit(".", 1)[-1] for f in names}
+
+ def _sort_key(name: str) -> tuple[int, str]:
+ if name.isupper():
+ return (0, name)
+ if name and name[0].isupper() and "_" not in name:
+ return (1, name)
+ return (2, name)
+
+ code.lines = [
+ code.lines[0],
+ *[f'{indent}"{name}",' for name in sorted(names, key=_sort_key)],
+ code.lines[-1],
+ ]
+
+
+def generate_export(code: CodeBlock) -> None:
+ """Generate a block that re-exports a submodule and extends `__all__` with its names."""
+ assert len(code.lines) >= 2
+
+ mod = code.param
code.lines = [
code.lines[0],
- *[f'{indent}"{name}",' for name in sorted(names)],
+ "# fmt: off",
+ "# isort: off",
+ f"from .{mod} import * # noqa: F403",
+ f"from .{mod} import __all__ as {mod}__all__",
+ 'if "__all__" not in globals():',
+ " __all__ = []",
+ f"__all__.extend({mod}__all__)",
+ "# isort: on",
+ "# fmt: on",
code.lines[-1],
]
+
+
+def generate_ffi_api(
+ code_blocks: list[CodeBlock],
+ ty_map: dict[str, str],
+ module_name: str,
+ object_infos: list[ObjectInfo],
+ init_cfg: InitConfig,
+ is_root: bool,
+) -> str:
+ """Generate the initial FFI API stub code for a given module."""
+ # TODO(@junrus): New code is appended to the end of the file.
+ # We should consider a more sophisticated approach.
+ append = ""
+
+ # Part 0. Imports
+ if not code_blocks:
+ append += f"""\"\"\"FFI API bindings for {module_name}.\"\"\"\n"""
+ if not any(code.kind == "import-section" for code in code_blocks):
+ append += C.PROMPT_IMPORT_SECTION
+
+ # Part 1. Library loading
+ if is_root:
+ append += C._prompt_import_object("tvm_ffi.libinfo.load_lib_module", "_FFI_LOAD_LIB")
+ append += f"""LIB = _FFI_LOAD_LIB("{init_cfg.pkg}", "{init_cfg.shared_target}")\n"""
+
+ # Part 2. Global functions
+ if not any(code.kind == "global" for code in code_blocks):
+ append += C._prompt_globals(module_name)
+
+ # Part 3. Object types
+ if object_infos:
+ append += C._prompt_import_object("tvm_ffi.register_object", "_FFI_REG_OBJ")
+
+ defined_type_keys = {info.type_key for info in object_infos if info.type_key}
+ for info in object_infos:
+ type_key = info.type_key
+ parent_type_key = info.parent_type_key
+ if type_key is None:
+ continue
+ # Canonicalize type key names
+ type_key = ty_map.get(type_key, type_key)
+ type_name = type_key.rsplit(".", 1)[-1]
+ parent_type_key = (
+ ty_map.get(parent_type_key, parent_type_key) if parent_type_key else parent_type_key
+ )
+ parent_type_name = parent_type_key.rsplit(".", 1)[-1] if parent_type_key else "Object"
+ # Import parent type keys if they are not defined in the current module
+ if parent_type_key and parent_type_key not in defined_type_keys:
+ parent_type_name = "_" + parent_type_key.replace(".", "_")
+ append += C._prompt_import_object(parent_type_key, parent_type_name)
+ # Generate class definition
+ append += C._prompt_class_def(
+ type_name,
+ type_key,
+ parent_type_name,
+ )
+ # Part 4. __all__
+ if not any(code.kind == "__all__" for code in code_blocks):
+ append += C.PROMPT_ALL_SECTION
+ return append
+
+
+def generate_init(
+ code_blocks: list[CodeBlock],
+ module_name: str,
+ submodule: str = "_ffi_api",
+) -> str:
+ """Generate the stub content to append to a package's `__init__.py`."""
+ code = f"""
+{C.STUB_BEGIN} export/{submodule}
+{C.STUB_END}
+"""
+ if not code_blocks:
+ return f"""\"\"\"Package {module_name}.\"\"\"\n""" + code
+ if not any(code.kind == "export" for code in code_blocks):
+ return code
+ return ""
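`generate_import_section` splits the collected `ImportItem`s into runtime imports and `TYPE_CHECKING`-only imports. It groups each side by module and emits one `from <mod> import ...` line per module, in sorted module order. The standalone sketch below reproduces that grouping. It uses a simplified `Item` dataclass as a hypothetical stand-in for `ImportItem` and omits the `MOD_MAP` rewriting.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Item:
    """Hypothetical simplified `ImportItem`: module, name, TYPE_CHECKING flag, alias."""

    mod: str
    name: str
    type_checking_only: bool = False
    alias: str | None = None

    @property
    def name_with_alias(self) -> str:
        return f"{self.name} as {self.alias}" if self.alias else self.name


def render_imports(items: list[Item], indent: int = 4) -> list[str]:
    """Group items by module: runtime imports first, the rest under `if TYPE_CHECKING:`."""
    concrete: dict[str, list[Item]] = {}
    ty_check: dict[str, list[Item]] = {}
    for item in items:
        target = ty_check if item.type_checking_only else concrete
        target.setdefault(item.mod, []).append(item)
    if ty_check:
        # The `if TYPE_CHECKING:` guard itself needs a runtime import of TYPE_CHECKING.
        concrete.setdefault("typing", []).append(Item("typing", "TYPE_CHECKING"))

    def make_line(mod: str, group: list[Item], pad: str) -> str:
        # Deduplicate and sort names so reruns produce identical text.
        names = ", ".join(sorted({it.name_with_alias for it in group}))
        return f"{pad}from {mod} import {names}" if mod else f"{pad}import {names}"

    lines = ["from __future__ import annotations"]
    lines += [make_line(mod, concrete[mod], "") for mod in sorted(concrete)]
    if ty_check:
        lines.append("if TYPE_CHECKING:")
        lines += [make_line(mod, ty_check[mod], " " * indent) for mod in sorted(ty_check)]
    return lines


if __name__ == "__main__":
    demo = [
        Item("tvm_ffi", "init_ffi_api", alias="_FFI_INIT_FUNC"),
        Item("typing", "TYPE_CHECKING"),
        Item("typing", "Any", type_checking_only=True),
        Item("collections.abc", "Sequence", type_checking_only=True),
    ]
    print("\n".join(render_imports(demo)))
```

The deduplication step means a name recorded many times, as `TYPE_CHECKING` is here, still appears only once in the rendered section.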
diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py
index 6922254..b6458e1 100644
--- a/python/tvm_ffi/stub/consts.py
+++ b/python/tvm_ffi/stub/consts.py
@@ -16,11 +16,26 @@
# under the License.
"""Constants used in stub generation."""
+from typing import Literal
+
+from typing_extensions import TypeAlias
+
STUB_PREFIX = "# tvm-ffi-stubgen("
STUB_BEGIN = f"{STUB_PREFIX}begin):"
STUB_END = f"{STUB_PREFIX}end)"
STUB_TY_MAP = f"{STUB_PREFIX}ty-map):"
+STUB_IMPORT_OBJECT = f"{STUB_PREFIX}import-object):"
STUB_SKIP_FILE = f"{STUB_PREFIX}skip-file)"
+STUB_BLOCK_KINDS: TypeAlias = Literal[
+ "global",
+ "object",
+ "ty-map",
+ "import-section",
+ "import-object",
+ "export",
+ "__all__",
+ None,
+]
TERM_RESET = "\033[0m"
TERM_BOLD = "\033[1m"
@@ -35,18 +50,15 @@ TERM_WHITE = "\033[37m"
DEFAULT_SOURCE_EXTS = {".py", ".pyi"}
TY_MAP_DEFAULTS = {
- "list": "collections.abc.Sequence",
- "dict": "collections.abc.Mapping",
-}
-
-TY_TO_IMPORT = {
"Any": "typing.Any",
"Callable": "typing.Callable",
"Mapping": "typing.Mapping",
- "Object": "tvm_ffi.Object",
- "Tensor": "tvm_ffi.Tensor",
- "dtype": "tvm_ffi.dtype",
- "Device": "tvm_ffi.Device",
+ "list": "collections.abc.Sequence",
+ "dict": "collections.abc.Mapping",
+ "Object": "ffi.Object",
+ "Tensor": "ffi.Tensor",
+ "dtype": "ffi.dtype",
+ "Device": "ffi.Device",
}
# TODO(@junrushao): Make it configurable
@@ -58,3 +70,47 @@ MOD_MAP = {
FN_NAME_MAP = {
"__ffi_init__": "__c_ffi_init__",
}
+
+BUILTIN_TYPE_KEYS = {
+ "ffi.Bytes",
+ "ffi.Error",
+ "ffi.Function",
+ "ffi.Object",
+ "ffi.OpaquePyObject",
+ "ffi.SmallBytes",
+ "ffi.SmallStr",
+ "ffi.String",
+ "ffi.Tensor",
+}
+
+
+def _prompt_globals(mod: str) -> str:
+ return f"""{STUB_BEGIN} global/{mod}
+{STUB_END}
+"""
+
+
+def _prompt_class_def(type_name: str, type_key: str, parent_type_name: str) -> str:
+ return f'''@_FFI_REG_OBJ("{type_key}")
+class {type_name}({parent_type_name}):
+ """FFI binding for `{type_key}`."""
+
+ {STUB_BEGIN} object/{type_key}
+ {STUB_END}\n\n'''
+
+
+def _prompt_import_object(type_key: str, type_name: str) -> str:
+ return f"""{STUB_IMPORT_OBJECT} {type_key};False;{type_name}\n"""
+
+
+PROMPT_IMPORT_SECTION = f"""
+{STUB_BEGIN} import-section
+{STUB_END}
+"""
+
+PROMPT_ALL_SECTION = f"""
+__all__ = [
+ {STUB_BEGIN} __all__
+ {STUB_END}
+]
+"""
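The new `_prompt_*` helpers return the placeholder text that `generate_ffi_api` appends for sections missing from a module. The sketch below copies `_prompt_import_object` and `_prompt_class_def`, written with four-space indentation. It shows the text produced for a class whose parent type lives outside the current module; as in `generate_ffi_api`, that parent is imported under a mangled alias. `scaffold_for` is a hypothetical helper that skips the `ty_map` canonicalization step, and the type key `my_ffi_extension.IntPair` is a made-up example.

```python
STUB_PREFIX = "# tvm-ffi-stubgen("
STUB_BEGIN = f"{STUB_PREFIX}begin):"
STUB_END = f"{STUB_PREFIX}end)"
STUB_IMPORT_OBJECT = f"{STUB_PREFIX}import-object):"


def prompt_import_object(type_key: str, type_name: str) -> str:
    # Same text as `_prompt_import_object`: a runtime (False) import with an alias.
    return f"{STUB_IMPORT_OBJECT} {type_key};False;{type_name}\n"


def prompt_class_def(type_name: str, type_key: str, parent_type_name: str) -> str:
    # Same text as `_prompt_class_def`: a registered class holding an empty object stub block.
    return f'''@_FFI_REG_OBJ("{type_key}")
class {type_name}({parent_type_name}):
    """FFI binding for `{type_key}`."""

    {STUB_BEGIN} object/{type_key}
    {STUB_END}\n\n'''


def scaffold_for(type_key: str, parent_type_key: str) -> str:
    # Hypothetical: parent defined elsewhere, so alias it, e.g. ffi.Object -> _ffi_Object.
    parent_alias = "_" + parent_type_key.replace(".", "_")
    type_name = type_key.rsplit(".", 1)[-1]
    return prompt_import_object(parent_type_key, parent_alias) + prompt_class_def(
        type_name, type_key, parent_alias
    )


if __name__ == "__main__":
    print(scaffold_for("my_ffi_extension.IntPair", "ffi.Object"))
```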
diff --git a/python/tvm_ffi/stub/file_utils.py b/python/tvm_ffi/stub/file_utils.py
index f100c55..055c56d 100644
--- a/python/tvm_ffi/stub/file_utils.py
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -20,9 +20,10 @@ from __future__ import annotations
import dataclasses
import difflib
+import os
import traceback
from pathlib import Path
-from typing import Callable, Generator, Iterable, Literal
+from typing import Callable, Generator, Iterable
from . import consts as C
@@ -31,15 +32,24 @@ from . import consts as C
class CodeBlock:
"""A block of code to be generated in a stub file."""
- kind: Literal["global", "object", "ty-map", "import", "__all__", None]
- param: str
+ kind: C.STUB_BLOCK_KINDS
+ param: str | tuple[str, ...]
lineno_start: int
lineno_end: int | None
lines: list[str]
def __post_init__(self) -> None:
"""Validate the code block after initialization."""
- assert self.kind in {"global", "object", "ty-map", "import", "__all__", None}
+ assert self.kind in {
+ "global",
+ "object",
+ "ty-map",
+ "import-section",
+ "import-object",
+ "export",
+ "__all__",
+ None,
+ }
@property
def indent(self) -> int:
@@ -53,27 +63,48 @@ class CodeBlock:
def from_begin_line(lineo: int, line: str) -> CodeBlock:
"""Parse a line to create a CodeBlock if it contains a stub begin marker."""
if line.startswith(C.STUB_TY_MAP):
+ line = line[len(C.STUB_TY_MAP) :].strip()
return CodeBlock(
kind="ty-map",
- param=line[len(C.STUB_TY_MAP) :].strip(),
+ param=line,
+ lineno_start=lineo,
+ lineno_end=lineo,
+ lines=[],
+ )
+ elif line.startswith(C.STUB_IMPORT_OBJECT):
+ line = line[len(C.STUB_IMPORT_OBJECT) :].strip()
+ splits = [p.strip() for p in line.split(";")]
+ if len(splits) < 3:
+ splits += [""] * (3 - len(splits))
+ return CodeBlock(
+ kind="import-object",
+ param=tuple(splits),
lineno_start=lineo,
lineno_end=lineo,
lines=[],
)
assert line.startswith(C.STUB_BEGIN)
+ param: str | tuple[str, ...]
stub = line[len(C.STUB_BEGIN) :].strip()
if stub.startswith("global/"):
kind = "global"
param = stub[len("global/") :].strip()
+ if "@" in param:
+ param = tuple(param.split("@"))
+ else:
+ param = (param, "")
elif stub.startswith("object/"):
kind = "object"
param = stub[len("object/") :].strip()
elif stub.startswith("ty-map/"):
kind = "ty-map"
param = stub[len("ty-map/") :].strip()
- elif stub.startswith("import"):
- kind = "import"
+ elif stub == "import-section":
+ kind = "import-section"
param = ""
+ elif stub.startswith("export/"):
+ kind = "export"
+ param = stub[len("export/") :].strip()
elif stub == "__all__":
kind = "__all__"
param = ""
@@ -96,12 +127,14 @@ class FileInfo:
lines: tuple[str, ...]
code_blocks: list[CodeBlock]
- def update(self, show_diff: bool, dry_run: bool) -> bool:
+ def update(self, verbose: bool, dry_run: bool) -> bool:
"""Update the file's lines based on the current code blocks and optionally show a diff."""
new_lines = tuple(line for block in self.code_blocks for line in block.lines)
if self.lines == new_lines:
+ if verbose:
+ print(f"{C.TERM_CYAN}-----> Unchanged{C.TERM_RESET}")
return False
- if show_diff:
+ if verbose:
for line in difflib.unified_diff(self.lines, new_lines, lineterm=""):
# Skip placeholder headers when fromfile/tofile are unspecified
if line.startswith("---") or line.startswith("+++"):
@@ -120,21 +153,18 @@ class FileInfo:
return True
@staticmethod
- def from_file(file: Path) -> FileInfo | None: # noqa: PLR0912
+ def from_file(file: Path, include_empty: bool = False) -> FileInfo | None: # noqa: PLR0912
"""Parse a file to extract code blocks based on stub markers."""
assert file.is_file(), f"Expected a file, but got: {file}"
file = file.resolve()
has_marker = False
lines: list[str] = file.read_text(encoding="utf-8").splitlines()
- for line_no, line in enumerate(lines, start=1):
+ for line in lines:
if line.strip().startswith(C.STUB_SKIP_FILE):
- print(
- f"{C.TERM_YELLOW}[Skipped] skip-file marker found on line {line_no}: {file}{C.TERM_RESET}"
- )
return None
if line.strip().startswith(C.STUB_PREFIX):
has_marker = True
- if not has_marker:
+ if not has_marker and not include_empty:
return None
del has_marker
@@ -142,31 +172,46 @@ class FileInfo:
code: CodeBlock | None = None
for lineno, line in enumerate(lines, 1):
clean_line = line.strip()
- if clean_line.startswith(C.STUB_BEGIN): # Process "# tvm-ffi-stubgen(begin)"
+ if clean_line.startswith(C.STUB_BEGIN):
+ # Process "# tvm-ffi-stubgen(begin)"
if code is not None:
raise ValueError(f"Nested stub not permitted at line {lineno}")
code = CodeBlock.from_begin_line(lineno, clean_line)
code.lineno_start = lineno
code.lines.append(line)
- elif clean_line.startswith(C.STUB_END): # Process "# tvm-ffi-stubgen(end)"
+ elif clean_line.startswith(C.STUB_END):
+ # Process "# tvm-ffi-stubgen(end)"
if code is None:
raise ValueError(f"Unmatched `{C.STUB_END}` found at line {lineno}")
code.lineno_end = lineno
code.lines.append(line)
codes.append(code)
code = None
- elif clean_line.startswith(C.STUB_TY_MAP): # Process "# tvm-ffi-stubgen(ty_map)"
+ elif clean_line.startswith(C.STUB_TY_MAP):
+ # Process "# tvm-ffi-stubgen(ty-map)"
ty_code = CodeBlock.from_begin_line(lineno, clean_line)
ty_code.lineno_end = lineno
ty_code.lines.append(line)
codes.append(ty_code)
del ty_code
+ elif clean_line.startswith(C.STUB_IMPORT_OBJECT):
+ # Process "# tvm-ffi-stubgen(import-object)"
+ imp_code = CodeBlock.from_begin_line(lineno, clean_line)
+ imp_code.lineno_end = lineno
+ imp_code.lines.append(line)
+ codes.append(imp_code)
+ del imp_code
elif clean_line.startswith(C.STUB_PREFIX):
raise ValueError(f"Unknown stub type at line {lineno}: {clean_line}")
- elif code is None: # Process a plain line outside of any stub block
+ elif code is None:
+ # Process a plain line outside of any stub block
codes.append(
CodeBlock(
- kind=None, param="", lineno_start=lineno, lineno_end=lineno, lines=[line]
+ kind=None,
+ param="",
+ lineno_start=lineno,
+ lineno_end=lineno,
+ lines=[line],
)
)
else: # Process a line inside a stub block
@@ -175,6 +220,12 @@ class FileInfo:
raise ValueError("Unclosed stub block at end of file")
return FileInfo(path=file, lines=tuple(lines), code_blocks=codes)
+ def reload(self) -> None:
+ """Reload the code blocks from disk while preserving original `lines`."""
+ source = FileInfo.from_file(self.path)
+ assert source is not None, f"File no longer exists or is no longer valid: {self.path}"
+ self.code_blocks = source.code_blocks
+
def collect_files(paths: list[Path]) -> list[FileInfo]:
"""Collect all files from the given paths and parse them into FileInfo objects."""
@@ -220,6 +271,8 @@ def path_walk(
follow_symlinks: bool = False,
) -> Iterable[tuple[Path, list[str], list[str]]]:
"""Compat wrapper for Path.walk (3.12+) with a fallback for < 3.12."""
+ if not p.exists():
+ return
# Python 3.12+ - just delegate to `Path.walk`
if hasattr(p, "walk"):
yield from p.walk( # type: ignore[attr-defined]
@@ -229,8 +282,6 @@ def path_walk(
)
return
# Python < 3.12 - use `os.walk`
- import os # noqa: PLC0415
-
for root_str, dirnames, filenames in os.walk(
p,
topdown=top_down,
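`CodeBlock.from_begin_line` now produces a tuple `param` for two marker kinds:

- An `import-object` line is split on `;` and padded with empty strings to at least three fields.
- `global/<prefix>@<import-from>` is split on `@`. A missing import-from becomes `""`, which `generate_global_funcs` later replaces with `tvm_ffi`.

The functions below are a standalone sketch of that parsing. Their names are hypothetical, and the second field of an `import-object` marker is left as the raw string, because its conversion is not shown in this diff.

```python
from __future__ import annotations

STUB_PREFIX = "# tvm-ffi-stubgen("
STUB_BEGIN = f"{STUB_PREFIX}begin):"
STUB_IMPORT_OBJECT = f"{STUB_PREFIX}import-object):"


def parse_import_object(line: str) -> tuple[str, ...]:
    """Parse `<from>; <type_checking_only>; <alias>`, padding to at least three fields."""
    assert line.startswith(STUB_IMPORT_OBJECT)
    body = line[len(STUB_IMPORT_OBJECT) :].strip()
    parts = [p.strip() for p in body.split(";")]
    parts += [""] * (3 - len(parts))  # adds nothing when 3+ fields are present
    return tuple(parts)


def parse_global_param(line: str) -> tuple[str, ...]:
    """Parse `global/<prefix>[@<import-from>]`; a missing import-from becomes ''."""
    assert line.startswith(STUB_BEGIN)
    stub = line[len(STUB_BEGIN) :].strip()
    assert stub.startswith("global/")
    param = stub[len("global/") :].strip()
    return tuple(param.split("@")) if "@" in param else (param, "")


if __name__ == "__main__":
    print(parse_import_object("# tvm-ffi-stubgen(import-object): ffi.Object;False;_ffi_Object"))
    print(parse_import_object("# tvm-ffi-stubgen(import-object): ffi.Object"))
    print(parse_global_param("# tvm-ffi-stubgen(begin): global/my_ffi_extension"))
```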
diff --git a/python/tvm_ffi/stub/lib_state.py b/python/tvm_ffi/stub/lib_state.py
new file mode 100644
index 0000000..ea0f0bc
--- /dev/null
+++ b/python/tvm_ffi/stub/lib_state.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Stateful helpers for querying TVM FFI runtime metadata."""
+
+from __future__ import annotations
+
+import functools
+import heapq
+from collections import defaultdict
+
+from tvm_ffi._ffi_api import GetRegisteredTypeKeys
+from tvm_ffi.core import TypeSchema, _lookup_or_register_type_info_from_type_key
+from tvm_ffi.registry import get_global_func_metadata, list_global_func_names
+
+from . import consts as C
+from .utils import FuncInfo, NamedTypeSchema, ObjectInfo
+
+
+@functools.lru_cache(maxsize=None)
+def object_info_from_type_key(type_key: str) -> ObjectInfo:
+ """Construct an `ObjectInfo` from an object type key."""
+ type_info = _lookup_or_register_type_info_from_type_key(str(type_key))
+ assert type_info.type_key == type_key
+ return ObjectInfo.from_type_info(type_info)
+
+
+def collect_global_funcs() -> dict[str, list[FuncInfo]]:
+ """Collect global functions from TVM FFI's global registry."""
+ global_funcs: dict[str, list[FuncInfo]] = {}
+ for name in list_global_func_names():
+ try:
+ prefix, _ = name.rsplit(".", 1)
+ except ValueError:
+ print(f"{C.TERM_YELLOW}[Skipped] Invalid name in global function: {name}{C.TERM_RESET}")
+ else:
+ try:
+ global_funcs.setdefault(prefix, []).append(_func_info_from_global_name(name))
+ except Exception:
+ print(f"{C.TERM_YELLOW}[Skipped] Function has no type schema: {name}{C.TERM_RESET}")
+ for k in list(global_funcs.keys()):
+ global_funcs[k].sort(key=lambda x: x.schema.name)
+ return global_funcs
+
+
+def collect_type_keys() -> dict[str, list[str]]:
+ """Collect registered object type keys from TVM FFI's global registry."""
+ global_objects: dict[str, list[str]] = {}
+ for type_key in GetRegisteredTypeKeys():
+ try:
+ prefix, _ = type_key.rsplit(".", 1)
+ except ValueError:
+ pass
+ else:
+ global_objects.setdefault(prefix, []).append(type_key)
+ for k in list(global_objects.keys()):
+ global_objects[k].sort()
+ return global_objects
+
+
+def toposort_objects(type_keys: list[str]) -> list[ObjectInfo]:
+ """Collect ObjectInfo objects for type keys, topologically sorted by inheritance."""
+ # Remove duplicates while preserving order.
+ unique_type_keys = list(dict.fromkeys(type_keys))
+ infos: dict[str, ObjectInfo] = {
+ type_key: object_info_from_type_key(type_key) for type_key in unique_type_keys
+ }
+
+ child_types: dict[str, list[str]] = defaultdict(list)
+ in_degree: dict[str, int] = defaultdict(int)
+ for type_key, info in infos.items():
+ parent_type_key = info.parent_type_key
+ if parent_type_key in infos:
+ child_types[parent_type_key].append(type_key)
+ in_degree[type_key] += 1
+ in_degree[parent_type_key] += 0
+ else:
+ in_degree[type_key] += 0
+
+ for children in child_types.values():
+ children.sort()
+
+ queue: list[str] = [ty for ty, deg in in_degree.items() if deg == 0]
+ heapq.heapify(queue)
+ sorted_keys: list[str] = []
+ while queue:
+ type_key = heapq.heappop(queue)
+ sorted_keys.append(type_key)
+ for child_type_key in child_types[type_key]:
+ in_degree[child_type_key] -= 1
+ if in_degree[child_type_key] == 0:
+ heapq.heappush(queue, child_type_key)
+
+ assert len(sorted_keys) == len(infos)
+ return [infos[type_key] for type_key in sorted_keys]
+
+
+@functools.lru_cache(maxsize=None)
+def _func_info_from_global_name(name: str) -> FuncInfo:
+ """Construct a `FuncInfo` from a global function name."""
+ return FuncInfo(
+ schema=NamedTypeSchema(
+ name=name,
+ schema=TypeSchema.from_json_str(get_global_func_metadata(name)["type_schema"]),
+ ),
+ is_member=False,
+ )
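`toposort_objects` orders types so that every parent is emitted before its children. This matters because `generate_ffi_api` writes one `class Child(Parent)` definition per type, and a base class defined in the same module must come first. The function is Kahn's algorithm with a min-heap as the ready queue, so ties break lexicographically and reruns give the same order. Below is a self-contained sketch over a plain `{type_key: parent_type_key}` mapping. The type keys are made up, and the explicit cycle check replaces the original's `assert`.

```python
from __future__ import annotations

import heapq
from collections import defaultdict


def toposort_type_keys(parents: dict[str, str | None]) -> list[str]:
    """Order type keys so that each parent precedes its children.

    Parents missing from `parents` are treated as external (already defined),
    as in `toposort_objects`. Among keys that are ready at the same time, the
    lexicographically smallest is emitted first.
    """
    children: dict[str, list[str]] = defaultdict(list)
    in_degree: dict[str, int] = {key: 0 for key in parents}
    for key, parent in parents.items():
        if parent in parents:  # only in-module parents create an edge
            children[parent].append(key)
            in_degree[key] += 1
    queue = [key for key, deg in in_degree.items() if deg == 0]
    heapq.heapify(queue)
    order: list[str] = []
    while queue:
        key = heapq.heappop(queue)
        order.append(key)
        for child in children[key]:
            in_degree[child] -= 1
            if in_degree[child] == 0:
                heapq.heappush(queue, child)
    if len(order) != len(parents):
        raise ValueError("inheritance cycle detected")
    return order


if __name__ == "__main__":
    print(
        toposort_type_keys(
            {
                "my_ext.Derived": "my_ext.Base",
                "my_ext.Base": "ffi.Object",
                "my_ext.Another": "my_ext.Base",
                "my_ext.Alone": None,
            }
        )
    )
```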
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
index e8beb41..e02bbc3 100644
--- a/python/tvm_ffi/stub/utils.py
+++ b/python/tvm_ffi/stub/utils.py
@@ -22,22 +22,106 @@ import dataclasses
from io import StringIO
from typing import Callable
-from tvm_ffi.core import TypeSchema
+from tvm_ffi.core import TypeInfo, TypeSchema
from . import consts as C
+@dataclasses.dataclass
+class InitConfig:
+ """Configuration for generating new stubs.
+
+ Examples
+ --------
+ If we are generating type stubs for Python package `my-ffi-extension`,
+ and the CMake target that generates the shared library is `my_ffi_extension_shared`,
+ then we can run the following command to generate the stubs:
+
+ --init-pypkg my-ffi-extension --init-lib my_ffi_extension_shared --init-prefix my_ffi_extension.
+
+ """
+
+ pkg: str
+ """Name of the Python distribution package to generate stubs for, e.g. apache-tvm-ffi (not the import name tvm_ffi)"""
+
+ shared_target: str
+ """Name of the CMake target that generates the shared library, e.g. tvm_ffi_shared
+
+ This is used to determine the name of the shared library file.
+ - macOS: lib{shared_target}.dylib or lib{shared_target}.so
+ - Linux: lib{shared_target}.so
+ - Windows: {shared_target}.dll
+ """
+
+ prefix: str
+ """Only generate stubs for global functions and objects with the given prefix, e.g. `tvm_ffi.`"""
+
+
@dataclasses.dataclass
class Options:
"""Command line options for stub generation."""
+ imports: list[str] = dataclasses.field(default_factory=list)
dlls: list[str] = dataclasses.field(default_factory=list)
+ init: InitConfig | None = None
indent: int = 4
files: list[str] = dataclasses.field(default_factory=list)
verbose: bool = False
dry_run: bool = False
+@dataclasses.dataclass(frozen=True, eq=True)
+class ImportItem:
+ """An import statement item."""
+
+ mod: str
+ name: str
+ type_checking_only: bool = False
+ alias: str | None = None
+
+ def __init__(
+ self,
+ name: str,
+ type_checking_only: bool = False,
+ alias: str | None = None,
+ ) -> None:
+ """Initialize an `ImportItem` from a dotted name (e.g. `typing.Any`) and an optional alias."""
+ if "." in name:
+ mod, name = name.rsplit(".", 1)
+ for mod_prefix, mod_replacement in C.MOD_MAP.items():
+ if mod.startswith(mod_prefix):
+ mod = mod.replace(mod_prefix, mod_replacement, 1)
+ break
+ else:
+ mod = ""
+ object.__setattr__(self, "mod", mod)
+ object.__setattr__(self, "name", name)
+ object.__setattr__(self, "type_checking_only", type_checking_only)
+ object.__setattr__(self, "alias", alias)
+
+ @property
+ def name_with_alias(self) -> str:
+ """Generate a string of the form `name as alias` if an alias is set, otherwise just `name`."""
+ return f"{self.name} as {self.alias}" if self.alias else self.name
+
+ @property
+ def full_name(self) -> str:
+ """Generate a string of the form `mod.name`, or just `name` if no module is set."""
+ return f"{self.mod}.{self.name}" if self.mod else self.name
+
+ def __repr__(self) -> str:
+ """Generate an import statement string for this item."""
+ return str(self)
+
+ def __str__(self) -> str:
+ """Generate an import statement string for this item."""
+ if self.mod:
+ ret = f"from {self.mod} import {self.name_with_alias}"
+ else:
+ ret = f"import {self.name_with_alias}"
+ return ret
+
+
@dataclasses.dataclass(init=False)
class NamedTypeSchema(TypeSchema):
"""A type schema with an associated name."""
@@ -58,17 +142,9 @@ class FuncInfo:
is_member: bool
@staticmethod
- def from_global_name(name: str) -> FuncInfo:
- """Construct a `FuncInfo` from a string name of this global function."""
- from tvm_ffi.registry import get_global_func_metadata # noqa: PLC0415
-
- return FuncInfo(
- schema=NamedTypeSchema(
- name=name,
- schema=TypeSchema.from_json_str(get_global_func_metadata(name)["type_schema"]),
- ),
- is_member=False,
- )
+ def from_schema(name: str, schema: TypeSchema, *, is_member: bool = False) -> FuncInfo:
+ """Construct a `FuncInfo` from a name and its type schema."""
+ return FuncInfo(schema=NamedTypeSchema(name=name, schema=schema), is_member=is_member)
def gen(self, ty_map: Callable[[str], str], indent: int) -> str:
"""Generate a function signature string for this function."""
@@ -108,13 +184,15 @@ class ObjectInfo:
fields: list[NamedTypeSchema]
methods: list[FuncInfo]
+ type_key: str | None = None
+ parent_type_key: str | None = None
@staticmethod
- def from_type_key(type_key: str) -> ObjectInfo:
- """Construct an `ObjectInfo` from a type key."""
- from tvm_ffi.core import _lookup_or_register_type_info_from_type_key  # noqa: PLC0415
-
- type_info = _lookup_or_register_type_info_from_type_key(type_key)
+ def from_type_info(type_info: TypeInfo) -> ObjectInfo:
+ """Construct an `ObjectInfo` from a `TypeInfo` instance."""
+ parent_type_key: str | None = None
+ if type_info.parent_type_info is not None:
+ parent_type_key = type_info.parent_type_info.type_key
return ObjectInfo(
fields=[
NamedTypeSchema(
@@ -133,6 +211,8 @@ class ObjectInfo:
)
for method in type_info.methods
],
+ type_key=type_info.type_key,
+ parent_type_key=parent_type_key,
)
 def gen_fields(self, ty_map: Callable[[str], str], indent: int) -> list[str]:
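The `ImportItem` class added above splits a dotted name at its last dot, rewrites a matching module prefix through `C.MOD_MAP`, and renders an import statement. Because it is a frozen dataclass with `eq=True`, it is hashable, so equal imports collapse when collected in a set (the tests compare `set(imports)`). Below is a minimal, self-contained sketch of that behavior, not the tvm-ffi code: `SketchImport` and its one-entry `MOD_MAP` are hypothetical stand-ins, because the real `C.MOD_MAP` contents do not appear in this diff.

```python
from __future__ import annotations

import dataclasses

# Hypothetical prefix-rewrite table standing in for tvm_ffi.stub.consts.MOD_MAP.
MOD_MAP = {"testing": "tvm_ffi.testing"}


@dataclasses.dataclass(frozen=True)
class SketchImport:
    """Hypothetical stand-in for ImportItem: `from mod import name [as alias]`."""

    mod: str
    name: str
    alias: str | None = None

    @classmethod
    def parse(cls, dotted: str, alias: str | None = None) -> SketchImport:
        # Split at the last dot; a name without dots gets an empty module.
        mod, _, name = dotted.rpartition(".")
        # Rewrite the first matching module prefix, if any.
        for prefix, replacement in MOD_MAP.items():
            if mod.startswith(prefix):
                mod = mod.replace(prefix, replacement, 1)
                break
        return cls(mod, name, alias)

    def __str__(self) -> str:
        target = f"{self.name} as {self.alias}" if self.alias else self.name
        return f"from {self.mod} import {target}" if self.mod else f"import {target}"


init_func = SketchImport.parse("..registry.init_ffi_api", alias="_FFI_INIT_FUNC")
print(init_func)  # from ..registry import init_ffi_api as _FFI_INIT_FUNC
print(SketchImport.parse("testing.TestIntPair"))  # from tvm_ffi.testing import TestIntPair
# frozen=True plus the default eq=True yields value-based hashing, so duplicates collapse.
print(len({SketchImport.parse("typing.Any"), SketchImport.parse("typing.Any")}))  # 1
```

The sketch uses `str.rpartition` in place of the `if "." in name` / `rsplit` branch, which yields an empty module for dotless names in a single step.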
diff --git a/examples/packaging/python/my_ffi_extension/base.py b/python/tvm_ffi/testing/__init__.py
similarity index 65%
rename from examples/packaging/python/my_ffi_extension/base.py
rename to python/tvm_ffi/testing/__init__.py
index 5f1005e..cd35736 100644
--- a/examples/packaging/python/my_ffi_extension/base.py
+++ b/python/tvm_ffi/testing/__init__.py
@@ -12,10 +12,20 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations.
-# Base logic to load library for extension package
-"""Utilities to locate and load the example extension shared library."""
+# specific language governing permissions and limitations
+# under the License.
+"""Testing utilities."""
-import tvm_ffi
-
-_LIB = tvm_ffi.libinfo.load_lib_module("my-ffi-extension", "my_ffi_extension")
+from .testing import (
+ TestIntPair,
+ TestObjectBase,
+ TestObjectDerived,
+ _SchemaAllTypes,
+ _TestCxxClassBase,
+ _TestCxxClassDerived,
+ _TestCxxClassDerivedDerived,
+ _TestCxxInitSubset,
+ add_one,
+ create_object,
+ make_unregistered_object,
+)
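The `_LIB = tvm_ffi.libinfo.load_lib_module("my-ffi-extension", "my_ffi_extension")` line removed above pairs the published package name with the shared-library basename, the same two values that `--init-pypkg` and `--init-lib` now supply. Per the `InitConfig.shared_target` docstring, that basename maps to a platform-specific file name. A minimal sketch of the mapping, keyed on `sys.platform`; `shared_lib_candidates` is a hypothetical helper, and the real loader in `tvm_ffi.libinfo` may also search particular directories that are not modeled here.

```python
from __future__ import annotations

import sys


def shared_lib_candidates(shared_target: str, platform: str = sys.platform) -> list[str]:
    """Hypothetical helper: possible file names for a CMake shared-library target.

    Mirrors the per-platform rules in the InitConfig.shared_target docstring.
    """
    if platform == "darwin":
        # macOS: the docstring allows either suffix.
        return [f"lib{shared_target}.dylib", f"lib{shared_target}.so"]
    if platform.startswith("linux"):
        return [f"lib{shared_target}.so"]
    if platform == "win32":
        # Windows: no "lib" prefix.
        return [f"{shared_target}.dll"]
    raise ValueError(f"unsupported platform: {platform!r}")


print(shared_lib_candidates("my_ffi_extension", "darwin"))
# ['libmy_ffi_extension.dylib', 'libmy_ffi_extension.so']
```

Passing the platform explicitly keeps the helper testable; by default it uses the running interpreter's `sys.platform`.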
diff --git a/python/tvm_ffi/testing/_ffi_api.py b/python/tvm_ffi/testing/_ffi_api.py
new file mode 100644
index 0000000..29453ba
--- /dev/null
+++ b/python/tvm_ffi/testing/_ffi_api.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI API for namespace `testing`."""
+
+# tvm-ffi-stubgen(begin): import-section
+# fmt: off
+# isort: off
+from __future__ import annotations
+from ..registry import init_ffi_api as _FFI_INIT_FUNC
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from collections.abc import Mapping, Sequence
+ from tvm_ffi import Device, Object, Tensor, dtype
+ from tvm_ffi.testing import TestIntPair
+ from typing import Any, Callable
+# isort: on
+# fmt: on
+# tvm-ffi-stubgen(end)
+
+# tvm-ffi-stubgen(begin): global/testing@..registry
+# fmt: off
+_FFI_INIT_FUNC("testing", __name__)
+if TYPE_CHECKING:
+ def TestIntPairSum(_0: TestIntPair, /) -> int: ...
+ def add_one(_0: int, /) -> int: ...
+ def apply(*args: Any) -> Any: ...
+ def echo(*args: Any) -> Any: ...
+ def get_add_one_c_symbol() -> int: ...
+ def get_mlir_add_one_c_symbol() -> int: ...
+ def make_unregistered_object() -> Object: ...
+ def nop(*args: Any) -> Any: ...
+ def object_use_count(_0: Object, /) -> int: ...
+ def optional_tensor_view_has_value(_0: Tensor | None, /) -> bool: ...
+ def run_check_signal(_0: int, /) -> None: ...
+ def schema_arr_map_opt(_0: Sequence[int | None], _1: Mapping[str, Sequence[int]], _2: str | None, /) -> Mapping[str, Sequence[int]]: ...
+ def schema_id_any(_0: Any, /) -> Any: ...
+ def schema_id_arr(_0: Sequence[Any], /) -> Sequence[Any]: ...
+ def schema_id_arr_int(_0: Sequence[int], /) -> Sequence[int]: ...
+ def schema_id_arr_obj(_0: Sequence[Object], /) -> Sequence[Object]: ...
+ def schema_id_arr_str(_0: Sequence[str], /) -> Sequence[str]: ...
+ def schema_id_bool(_0: bool, /) -> bool: ...
+ def schema_id_bytes(_0: bytes, /) -> bytes: ...
+ def schema_id_device(_0: Device, /) -> Device: ...
+ def schema_id_dltensor(_0: Tensor, /) -> Tensor: ...
+ def schema_id_dtype(_0: dtype, /) -> dtype: ...
+ def schema_id_float(_0: float, /) -> float: ...
+ def schema_id_func(_0: Callable[..., Any], /) -> Callable[..., Any]: ...
+ def schema_id_func_typed(_0: Callable[[int, float, Callable[..., Any]], None], /) -> Callable[[int, float, Callable[..., Any]], None]: ...
+ def schema_id_int(_0: int, /) -> int: ...
+ def schema_id_map(_0: Mapping[Any, Any], /) -> Mapping[Any, Any]: ...
+ def schema_id_map_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]: ...
+ def schema_id_map_str_obj(_0: Mapping[str, Object], /) -> Mapping[str, Object]: ...
+ def schema_id_map_str_str(_0: Mapping[str, str], /) -> Mapping[str, str]: ...
+ def schema_id_object(_0: Object, /) -> Object: ...
+ def schema_id_opt_int(_0: int | None, /) -> int | None: ...
+ def schema_id_opt_obj(_0: Object | None, /) -> Object | None: ...
+ def schema_id_opt_str(_0: str | None, /) -> str | None: ...
+ def schema_id_string(_0: str, /) -> str: ...
+ def schema_id_tensor(_0: Tensor, /) -> Tensor: ...
+ def schema_id_variant_int_str(_0: int | str, /) -> int | str: ...
+ def schema_no_args() -> int: ...
+ def schema_no_args_no_return() -> None: ...
+ def schema_no_return(_0: int, /) -> None: ...
+ def schema_packed(*args: Any) -> Any: ...
+ def schema_tensor_view_input(_0: Tensor, /) -> None: ...
+ def schema_variant_mix(_0: int | str | Sequence[int], /) -> int | str | Sequence[int]: ...
+ def test_raise_error(_0: str, _1: str, /) -> None: ...
+# fmt: on
+# tvm-ffi-stubgen(end)
+
+__all__ = [
+ # tvm-ffi-stubgen(begin): __all__
+ "TestIntPairSum",
+ "add_one",
+ "apply",
+ "echo",
+ "get_add_one_c_symbol",
+ "get_mlir_add_one_c_symbol",
+ "make_unregistered_object",
+ "nop",
+ "object_use_count",
+ "optional_tensor_view_has_value",
+ "run_check_signal",
+ "schema_arr_map_opt",
+ "schema_id_any",
+ "schema_id_arr",
+ "schema_id_arr_int",
+ "schema_id_arr_obj",
+ "schema_id_arr_str",
+ "schema_id_bool",
+ "schema_id_bytes",
+ "schema_id_device",
+ "schema_id_dltensor",
+ "schema_id_dtype",
+ "schema_id_float",
+ "schema_id_func",
+ "schema_id_func_typed",
+ "schema_id_int",
+ "schema_id_map",
+ "schema_id_map_str_int",
+ "schema_id_map_str_obj",
+ "schema_id_map_str_str",
+ "schema_id_object",
+ "schema_id_opt_int",
+ "schema_id_opt_obj",
+ "schema_id_opt_str",
+ "schema_id_string",
+ "schema_id_tensor",
+ "schema_id_variant_int_str",
+ "schema_no_args",
+ "schema_no_args_no_return",
+ "schema_no_return",
+ "schema_packed",
+ "schema_tensor_view_input",
+ "schema_variant_mix",
+ "test_raise_error",
+ # tvm-ffi-stubgen(end)
+]
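Each generated region is delimited by `# tvm-ffi-stubgen(begin): <kind>/<param>` and `# tvm-ffi-stubgen(end)`. Judging from the `test_codeblock_from_begin_line_variants` and `generate_global_funcs` tests in this commit, a `global` block's parameter splits on `@` into a registry prefix and the module that provides `init_ffi_api`, while other kinds keep the parameter as a plain string. A minimal sketch of that parsing, assuming the marker text shown in the generated file above; `parse_begin_line` is hypothetical and is not `CodeBlock.from_begin_line`.

```python
from __future__ import annotations

# Assumed marker text, as it appears in the generated files.
STUB_BEGIN = "# tvm-ffi-stubgen(begin):"


def parse_begin_line(line: str) -> tuple[str, str | tuple[str, str]]:
    """Hypothetical parser for a stub begin marker: returns (kind, param)."""
    text = line.strip()  # markers may be indented, e.g. inside __all__ = [...]
    if not text.startswith(STUB_BEGIN):
        raise ValueError(f"not a stub begin line: {line!r}")
    spec = text[len(STUB_BEGIN):].strip()
    kind, _, param = spec.partition("/")
    if kind == "global":
        # "<prefix>@<import_from>"; without "@", import_from is empty.
        prefix, _, import_from = param.partition("@")
        return kind, (prefix, import_from)
    return kind, param


print(parse_begin_line("# tvm-ffi-stubgen(begin): global/demo@mockpkg"))
# ('global', ('demo', 'mockpkg'))
print(parse_begin_line("# tvm-ffi-stubgen(begin): object/demo.TypeBase"))
# ('object', 'demo.TypeBase')
print(parse_begin_line("    # tvm-ffi-stubgen(begin): __all__"))
# ('__all__', '')
```

Splitting on the first `/` and then the first `@` keeps dotted type keys such as `demo.TypeBase` and relative modules such as `..registry` intact.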
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing/testing.py
similarity index 95%
rename from python/tvm_ffi/testing.py
rename to python/tvm_ffi/testing/testing.py
index 2b157b1..b905b5b 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -17,24 +17,25 @@
"""Testing utilities."""
# ruff: noqa: D102,D105
-# tvm-ffi-stubgen(begin): import
+# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from tvm_ffi import Device, Object, dtype
+ from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
from typing import ClassVar
-from . import _ffi_api
-from .core import Object
-from .dataclasses import c_class, field
-from .registry import get_global_func, register_object
+from .. import _ffi_api
+from ..core import Object
+from ..dataclasses import c_class, field
+from ..registry import get_global_func, register_object
@register_object("testing.TestObjectBase")
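Both import sections shown above follow the same layout: `from __future__ import annotations`, then runtime imports, then an `if TYPE_CHECKING:` guard holding annotation-only imports, with names merged per module. The `test_generate_import_section_groups_modules` test in this commit also expects `from typing import TYPE_CHECKING` to appear even when the caller did not request it. A minimal sketch of that grouping, assuming imports arrive as `(module, name, type_checking_only)` triples; `Item` and `render_import_section` are hypothetical, and aliases and bare `import x` statements are left out.

```python
from __future__ import annotations

from collections import defaultdict
from typing import NamedTuple


class Item(NamedTuple):
    """Hypothetical import request: `from mod import name`."""

    mod: str
    name: str
    type_checking_only: bool = False


def render_import_section(items: list[Item], indent: int = 4) -> list[str]:
    """Hypothetical renderer for the body of an import-section block."""
    if not items:
        return []  # nothing requested: leave the block untouched
    runtime: dict[str, set[str]] = defaultdict(set)
    typing_only: dict[str, set[str]] = defaultdict(set)
    for item in items:
        group = typing_only if item.type_checking_only else runtime
        group[item.mod].add(item.name)
    if typing_only:
        # The `if TYPE_CHECKING:` guard needs TYPE_CHECKING at runtime.
        runtime["typing"].add("TYPE_CHECKING")
    lines = ["# fmt: off", "# isort: off", "from __future__ import annotations"]
    for mod in sorted(runtime):
        lines.append(f"from {mod} import {', '.join(sorted(runtime[mod]))}")
    if typing_only:
        lines.append("if TYPE_CHECKING:")
        pad = " " * indent
        for mod in sorted(typing_only):
            lines.append(f"{pad}from {mod} import {', '.join(sorted(typing_only[mod]))}")
    lines += ["# isort: on", "# fmt: on"]
    return lines


for line in render_import_section([
    Item("..registry", "init_ffi_api"),
    Item("collections.abc", "Sequence", type_checking_only=True),
    Item("collections.abc", "Mapping", type_checking_only=True),
    Item("typing", "Any", type_checking_only=True),
]):
    print(line)
```

Keeping the two groups apart lets annotation-only names (`Mapping`, `Tensor` and similar) stay out of the runtime import graph, while `from __future__ import annotations` keeps the stub signatures valid.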
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
index ef53143..c0d3dd5 100644
--- a/tests/python/test_stubgen.py
+++ b/tests/python/test_stubgen.py
@@ -14,31 +14,53 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from pathlib import Path
import pytest
from tvm_ffi.core import TypeSchema
from tvm_ffi.stub import consts as C
+from tvm_ffi.stub.cli import _stage_3
from tvm_ffi.stub.codegen import (
generate_all,
+ generate_export,
+ generate_ffi_api,
generate_global_funcs,
- generate_imports,
+ generate_import_section,
+ generate_init,
generate_object,
)
from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
-from tvm_ffi.stub.utils import FuncInfo, NamedTypeSchema, ObjectInfo, Options
+from tvm_ffi.stub.utils import (
+ FuncInfo,
+ ImportItem,
+ InitConfig,
+ NamedTypeSchema,
+ ObjectInfo,
+ Options,
+)
def _identity_ty_map(name: str) -> str:
return name
+def _default_ty_map() -> dict[str, str]:
+ return C.TY_MAP_DEFAULTS.copy()
+
+
+def _type_suffix(name: str) -> str:
+ return C.TY_MAP_DEFAULTS.get(name, name).rsplit(".", 1)[-1]
+
+
def test_codeblock_from_begin_line_variants() -> None:
cases = [
- (f"{C.STUB_BEGIN} global/example", "global", "example"),
- (f"{C.STUB_BEGIN} object/testing.TestObjectBase", "object", "testing.TestObjectBase"),
+ (f"{C.STUB_BEGIN} global/demo", "global", ("demo", "")),
+ (f"{C.STUB_BEGIN} global/demo@.registry", "global", ("demo", ".registry")),
+ (f"{C.STUB_BEGIN} object/demo.TypeBase", "object", "demo.TypeBase"),
(f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"),
- (f"{C.STUB_BEGIN} import", "import", ""),
+ (f"{C.STUB_BEGIN} import-section", "import-section", ""),
]
for lineno, (line, kind, param) in enumerate(cases, start=1):
block = CodeBlock.from_begin_line(lineno, line)
@@ -93,7 +115,7 @@ def test_fileinfo_from_file_parses_blocks(tmp_path: Path) -> None:
assert first.kind is None and first.lines == ["first = 1"]
assert stub.kind == "global"
- assert stub.param == "demo.func"
+ assert stub.param == ("demo.func", "")
assert stub.lineno_start == 2
assert stub.lineno_end == 4
assert stub.lines == [
@@ -212,25 +234,31 @@ def test_objectinfo_gen_fields_and_methods() -> None:
def test_generate_global_funcs_updates_block() -> None:
code = CodeBlock(
kind="global",
- param="testing",
+ param=("demo", "mockpkg"),
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END],
+ lines=[f"{C.STUB_BEGIN} global/demo@mockpkg", C.STUB_END],
)
funcs = [
FuncInfo(
schema=NamedTypeSchema(
- "testing.add_one",
+ "demo.add_one",
TypeSchema("Callable", (TypeSchema("int"), TypeSchema("int"))),
),
is_member=False,
)
]
opts = Options(indent=2)
- generate_global_funcs(code, funcs, _identity_ty_map, opts)
+ imports: list[ImportItem] = []
+ generate_global_funcs(code, funcs, _default_ty_map(), imports, opts)
+ assert imports == [
+ ImportItem("mockpkg.init_ffi_api", alias="_FFI_INIT_FUNC"),
+ ImportItem("typing.TYPE_CHECKING"),
+ ]
assert code.lines == [
- f"{C.STUB_BEGIN} global/testing",
+ f"{C.STUB_BEGIN} global/demo@mockpkg",
"# fmt: off",
+ '_FFI_INIT_FUNC("demo", __name__)',
"if TYPE_CHECKING:",
" def add_one(_0: int, /) -> int: ...",
"# fmt: on",
@@ -241,31 +269,71 @@ def test_generate_global_funcs_updates_block() -> None:
def test_generate_global_funcs_noop_on_empty_list() -> None:
code = CodeBlock(
kind="global",
- param="empty",
+ param=("empty", ""),
lineno_start=1,
lineno_end=2,
lines=[f"{C.STUB_BEGIN} global/empty", C.STUB_END],
)
- generate_global_funcs(code, [], _identity_ty_map, Options())
+ imports: list[ImportItem] = []
+ generate_global_funcs(code, [], _default_ty_map(), imports, Options())
assert code.lines == [f"{C.STUB_BEGIN} global/empty", C.STUB_END]
+ assert imports == []
+
+
+def test_generate_global_funcs_respects_custom_import_from() -> None:
+ code = CodeBlock(
+ kind="global",
+ param=("demo", "custom.mod"),
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} global/demo@custom.mod", C.STUB_END],
+ )
+ funcs = [
+ FuncInfo(
+ schema=NamedTypeSchema(
+ "demo.add_one",
+ TypeSchema("Callable", (TypeSchema("int"), TypeSchema("int"))),
+ ),
+ is_member=False,
+ )
+ ]
+ imports: list[ImportItem] = []
+ generate_global_funcs(code, funcs, _default_ty_map(), imports, Options(indent=0))
+ assert ImportItem("custom.mod.init_ffi_api", alias="_FFI_INIT_FUNC") in imports
def test_generate_object_fields_only_block() -> None:
code = CodeBlock(
kind="object",
- param="testing.TestObjectDerived",
+ param="demo.TypeDerived",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} object/testing.TestObjectDerived", C.STUB_END],
+ lines=[f"{C.STUB_BEGIN} object/demo.TypeDerived", C.STUB_END],
)
opts = Options(indent=4)
- generate_object(code, _identity_ty_map, opts)
+ imports: list[ImportItem] = []
+ info = ObjectInfo(
+ fields=[
+ NamedTypeSchema("field_a", TypeSchema("int")),
+ NamedTypeSchema("field_b", TypeSchema("float")),
+ ],
+ methods=[],
+ type_key="demo.TypeDerived",
+ parent_type_key="demo.Parent",
+ )
+ generate_object(
+ code,
+ _default_ty_map(),
+ imports,
+ opts,
+ info,
+ )
+ assert imports == []
- info = ObjectInfo.from_type_key("testing.TestObjectDerived")
expected = [
- f"{C.STUB_BEGIN} object/testing.TestObjectDerived",
+ f"{C.STUB_BEGIN} object/demo.TypeDerived",
" " * code.indent + "# fmt: off",
- *[(" " * code.indent) + line for line in info.gen_fields(_identity_ty_map, indent=0)],
+ *[(" " * code.indent) + line for line in info.gen_fields(_type_suffix, indent=0)],
" " * code.indent + "# fmt: on",
C.STUB_END,
]
@@ -275,15 +343,34 @@ def test_generate_object_fields_only_block() -> None:
def test_generate_object_with_methods() -> None:
code = CodeBlock(
kind="object",
- param="testing.TestIntPair",
+ param="demo.IntPair",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} object/testing.TestIntPair", C.STUB_END],
+ lines=[f"{C.STUB_BEGIN} object/demo.IntPair", C.STUB_END],
)
opts = Options(indent=4)
- generate_object(code, _identity_ty_map, opts)
+ imports: list[ImportItem] = []
+ info = ObjectInfo(
+ fields=[],
+ methods=[
+ FuncInfo.from_schema(
+ "demo.IntPair.__c_ffi_init__",
+ TypeSchema("Callable", (TypeSchema("None"), TypeSchema("int"), TypeSchema("int"))),
+ is_member=True,
+ ),
+ FuncInfo.from_schema(
+ "demo.IntPair.sum",
+ TypeSchema("Callable", (TypeSchema("int"),)),
+ is_member=True,
+ ),
+ ],
+ type_key="demo.IntPair",
+ parent_type_key="demo.Parent",
+ )
+ generate_object(code, _default_ty_map(), imports, opts, info)
+ assert set(imports) == {ImportItem("typing.TYPE_CHECKING")}
- assert code.lines[0] == f"{C.STUB_BEGIN} object/testing.TestIntPair"
+ assert code.lines[0] == f"{C.STUB_BEGIN} object/demo.IntPair"
assert code.lines[-1] == C.STUB_END
assert "# fmt: off" in code.lines[1]
assert any("if TYPE_CHECKING:" in line for line in code.lines)
@@ -294,38 +381,52 @@ def test_generate_object_with_methods() -> None:
assert any(line.strip().startswith("def sum") for line in method_lines)
-def test_generate_imports_groups_modules() -> None:
+def test_generate_import_section_groups_modules() -> None:
code = CodeBlock(
- kind="import",
+ kind="import-section",
param="",
lineno_start=1,
lineno_end=2,
lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
)
- ty_used = {
- "typing.Any",
- "tvm_ffi.Tensor",
- "testing.TestObjectBase",
- "custom.mod.Type",
- }
+ imports = [
+ ImportItem("typing.Any", type_checking_only=True),
+ ImportItem("demo_pkg.Tensor", type_checking_only=True),
+ ImportItem("demo.TestObjectBase", type_checking_only=True),
+ ImportItem("custom.mod.Type", type_checking_only=True),
+ ]
opts = Options(indent=4)
- generate_imports(code, ty_used, opts)
+ generate_import_section(code, imports, opts)
expected_prefix = [
f"{C.STUB_BEGIN} import",
"# fmt: off",
"# isort: off",
"from __future__ import annotations",
- "from typing import Any, TYPE_CHECKING",
+ "from typing import TYPE_CHECKING",
"if TYPE_CHECKING:",
]
assert code.lines[: len(expected_prefix)] == expected_prefix
- assert " from tvm_ffi.testing import TestObjectBase" in code.lines
- assert " from tvm_ffi import Tensor" in code.lines
+ assert " from demo import TestObjectBase" in code.lines
+ assert " from demo_pkg import Tensor" in code.lines
assert " from custom.mod import Type" in code.lines
+ assert " from typing import Any" in code.lines
assert code.lines[-2:] == ["# fmt: on", C.STUB_END]
+def test_generate_import_section_no_imports_noop() -> None:
+ code = CodeBlock(
+ kind="import-section",
+ param="",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
+ )
+ before = list(code.lines)
+ generate_import_section(code, [], Options())
+ assert code.lines == before
+
+
def test_generate_all_builds_sorted_and_deduped_list() -> None:
code = CodeBlock(
kind="global",
@@ -359,3 +460,147 @@ def test_generate_all_noop_on_empty_names() -> None:
before = list(code.lines)
generate_all(code, names=set(), opt=Options())
assert code.lines == before
+
+
+def test_generate_all_uses_isort_style_ordering() -> None:
+ code = CodeBlock(
+ kind="global",
+ param="all-mixed",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[C.STUB_BEGIN + " global/all-mixed", C.STUB_END],
+ )
+ names = {"foo", "Bar", "LIB", "baz", "Alpha", "CONST"}
+ generate_all(code, names=names, opt=Options(indent=0))
+ assert code.lines == [
+ C.STUB_BEGIN + " global/all-mixed",
+ '"CONST",',
+ '"LIB",',
+ '"Alpha",',
+ '"Bar",',
+ '"baz",',
+ '"foo",',
+ C.STUB_END,
+ ]
+
+
+def test_stage_3_adds_LIB_when_load_lib_imported(tmp_path: Path) -> None:
+ path = tmp_path / "demo.py"
+ global_block = CodeBlock(
+ kind="global",
+ param=("testing", ""),
+ lineno_start=2,
+ lineno_end=3,
+ lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END],
+ )
+ import_obj_block = CodeBlock(
+ kind="import-object",
+ param=("tvm_ffi.libinfo.load_lib_module", "False", "_FFI_LOAD_LIB"),
+ lineno_start=1,
+ lineno_end=1,
+ lines=[f"{C.STUB_IMPORT_OBJECT} tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB"],
+ )
+ all_block = CodeBlock(
+ kind="__all__",
+ param="",
+ lineno_start=4,
+ lineno_end=5,
+ lines=[f"{C.STUB_BEGIN} __all__", C.STUB_END],
+ )
+ file_info = FileInfo(
+ path=path,
+ lines=tuple(
+ line for block in (import_obj_block, global_block, all_block) for line in block.lines
+ ),
+ code_blocks=[import_obj_block, global_block, all_block],
+ )
+ funcs = [
+ FuncInfo.from_schema(
+ "testing.add_one",
+ TypeSchema("Callable", (TypeSchema("int"), TypeSchema("int"))),
+ )
+ ]
+ _stage_3(
+ file_info,
+ Options(dry_run=True),
+ _default_ty_map(),
+ {"testing": funcs},
+ )
+ lib_lines = [line for line in all_block.lines if "LIB" in line]
+ assert any("LIB" in line for line in lib_lines)
+
+
+def test_generate_export_builds_all_extension() -> None:
+ code = CodeBlock(
+ kind="export",
+ param="ffi_api",
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.STUB_BEGIN} export/ffi_api", C.STUB_END],
+ )
+ generate_export(code)
+ full_text = "\n".join(code.lines)
+ assert "from .ffi_api import *" in full_text
+ assert "ffi_api__all__" in full_text
+
+
+def test_generate_init_with_and_without_existing_export_block() -> None:
+ code_no_blocks = generate_init([], "demo")
+ assert "Package demo." in code_no_blocks
+ assert f"{C.STUB_BEGIN} export/_ffi_api" in code_no_blocks
+
+ code_with_export = generate_init(
+ [
+ CodeBlock(
+ kind="export",
+ param="_ffi_api",
+ lineno_start=1,
+ lineno_end=2,
+ lines=["", ""],
+ )
+ ],
+ "demo",
+ )
+ assert code_with_export == ""
+
+
+def test_generate_ffi_api_without_objects_includes_sections() -> None:
+ init_cfg = InitConfig(pkg="pkg", shared_target="pkg_shared", prefix="pkg.")
+ code = generate_ffi_api(
+ [],
+ _default_ty_map(),
+ "demo.mod",
+ [],
+ init_cfg,
+ is_root=False,
+ )
+ assert f"{C.STUB_BEGIN} import-section" in code
+ assert f"{C.STUB_BEGIN} global/demo.mod" in code
+ assert C.STUB_BEGIN + " __all__" in code
+ assert "LIB =" not in code
+
+
+def test_generate_ffi_api_with_objects_imports_parents() -> None:
+ init_cfg = InitConfig(pkg="pkg", shared_target="pkg_shared", prefix="pkg.")
+ obj_info = ObjectInfo(
+ fields=[],
+ methods=[],
+ type_key="demo.TypeDerived",
+ parent_type_key="demo.Parent",
+ )
+ parent_key = obj_info.parent_type_key
+ code = generate_ffi_api(
+ [],
+ _default_ty_map(),
+ "demo",
+ [obj_info],
+ init_cfg,
+ is_root=False,
+ )
+ assert C.STUB_IMPORT_OBJECT in code # register_object prompt
+ assert f"{C.STUB_BEGIN} object/{obj_info.type_key}" in code
+ assert parent_key is not None
+ parent_import_prompt = (
+ f"{C.STUB_IMPORT_OBJECT} {parent_key};False;_{parent_key.replace('.', '_')}"
+ )
+ assert parent_import_prompt in code
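The `test_generate_all_uses_isort_style_ordering` test expects `__all__` entries grouped as constants, then classes, then everything else, each group sorted and free of duplicates; the generated `__all__` in `testing/_ffi_api.py` follows the same pattern (`TestIntPairSum` before the lowercase names). A minimal sketch of a sort key with that behavior; `isort_style_key` and `render_all_entries` are hypothetical names, and isort's own `order_by_type` handling may differ in edge cases, so treat this as an approximation rather than the stub generator's code.

```python
from __future__ import annotations

from collections.abc import Iterable


def isort_style_key(name: str) -> tuple[int, str]:
    """Hypothetical sort key: CONSTANTS, then Classes, then functions/variables."""
    if name.isupper():
        return (0, name)  # e.g. LIB, CONST
    if name[:1].isupper():
        return (1, name)  # e.g. Alpha, TestIntPairSum
    return (2, name)  # e.g. add_one, foo


def render_all_entries(names: Iterable[str], indent: int = 4) -> list[str]:
    """Render deduplicated, ordered `__all__` entries, one quoted name per line."""
    pad = " " * indent
    return [f'{pad}"{name}",' for name in sorted(set(names), key=isort_style_key)]


for line in render_all_entries({"foo", "Bar", "LIB", "baz", "Alpha", "CONST"}, indent=0):
    print(line)
```

Deduplicating through `set()` before sorting is what keeps reruns from appending the same names twice.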