This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new bb2e4586 [PY] Distinguish FFI input and output annotations (#621)
bb2e4586 is described below
commit bb2e4586fff47ed75331b44639b1fb0ffb6f1e50
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Jun 23 10:23:26 2026 -0400
[PY] Distinguish FFI input and output annotations (#621)
## Summary
- Add separate input/output rendering APIs for `TypeSchema` so stubgen
can widen callable and constructor inputs without weakening returned or
field types.
- Add a `__ffi_convert_type_schema__` TypeAttr hook that declares the input
type schema accepted by an object's `__ffi_convert__`, and pass `converter`
metadata to the `dataclass_transform` decorators of `py_class`/`c_class`.
- Generate `type_schema` metadata for Python-defined `@method`
TypeMethods so reflection-driven stubgen can render their signatures.
- Add regression tests covering input/output container asymmetry, recursive
expansion of object convert schemas, stubgen signatures, partial type maps,
`@method` metadata, and dataclass metadata.
- Update the PR CI workflow to avoid Apache-disallowed third-party setup
actions and run pre-commit without syncing the project during lint.
## Validation
- pre-commit hooks during commits, including ruff, ty, clang-format,
cython-lint, CMake lint/format, and workflow YAML checks
- `python -m pytest
tests/python/test_stubgen.py::test_py_class_method_metadata_renders_stub_signature
tests/python/test_typed_method.py -q`
- `python -m pytest tests/python/test_stubgen.py -q`
- `CUDA_VISIBLE_DEVICES="" python -m pytest tests/python -q` (`2322
passed, 18 skipped, 2 xfailed`)
- Manual check output saved under
`tmp/2026-06-14-update-ffi/manual_check.out`
- GitHub CI passes for lint, docs, Ubuntu, Ubuntu ARM, macOS, and
Windows
---
CMakeLists.txt | 3 +
cmake/Utils/AddGoogleTest.cmake | 2 +-
include/tvm/ffi/reflection/accessor.h | 8 ++
python/tvm_ffi/core.pyi | 2 +
python/tvm_ffi/cython/type_info.pxi | 69 ++++++++++--
python/tvm_ffi/dataclasses/c_class.py | 9 +-
python/tvm_ffi/dataclasses/field.py | 19 +++-
python/tvm_ffi/dataclasses/py_class.py | 81 ++++++++++++--
python/tvm_ffi/stub/python_generator/codegen.py | 30 +++++-
python/tvm_ffi/stub/python_generator/consts.py | 7 ++
python/tvm_ffi/stub/python_generator/utils.py | 44 +++++---
tests/python/test_dataclass_c_class.py | 9 +-
tests/python/test_dataclass_py_class.py | 14 +++
tests/python/test_function.py | 16 +--
tests/python/test_stubgen.py | 137 ++++++++++++++++++++++++
tests/python/test_type_converter.py | 51 ++++++++-
16 files changed, 452 insertions(+), 49 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d306bb2a..d0cb8f08 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -247,10 +247,13 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
set(_cython_sources
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/core.pyx
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/base.pxi
+ ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/type_info.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/device.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/dtype.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/error.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/function.pxi
+ ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/pycallback.pxi
+
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/pyclass_type_converter.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tensor.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/object.pxi
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/string.pxi
diff --git a/cmake/Utils/AddGoogleTest.cmake b/cmake/Utils/AddGoogleTest.cmake
index cd8d0931..8b4800da 100644
--- a/cmake/Utils/AddGoogleTest.cmake
+++ b/cmake/Utils/AddGoogleTest.cmake
@@ -91,7 +91,7 @@ macro (TVM_FFI_ADD_GTEST target_name)
target_link_libraries(${target_name} PRIVATE gtest_main)
gtest_discover_tests(
${target_name}
- WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} TEST_DISCOVERY_TIMEOUT 600
DISCOVERY_MODE PRE_TEST
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} DISCOVERY_TIMEOUT 600
DISCOVERY_MODE PRE_TEST
PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
)
set_target_properties(${target_name} PROPERTIES FOLDER tests)
diff --git a/include/tvm/ffi/reflection/accessor.h
b/include/tvm/ffi/reflection/accessor.h
index 80de4931..d0aa734c 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -350,6 +350,14 @@ inline constexpr const char* kInit = "__ffi_init__";
* Signature: ``(AnyView src) -> TSelf``, where ``TSelf`` is a subclass of
ObjectRef.
*/
inline constexpr const char* kConvert = "__ffi_convert__";
+/*!
+ * \brief Type schema accepted by ``kConvert`` before it returns ``TSelf``.
+ *
+ * Stored as a JSON type schema string or schema-like object. Python stub
+ * generation uses this attribute to render widened input annotations while
+ * keeping output annotations precise.
+ */
+inline constexpr const char* kConvertTypeSchema =
"__ffi_convert_type_schema__";
/*!
* \brief Shallow-copy factory.
*
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index c6512365..59e68ddb 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -274,6 +274,8 @@ class TypeSchema:
@staticmethod
def from_annotation(annotation: object) -> TypeSchema: ...
def repr(self, ty_map: Callable[[str], str] | None = None) -> str: ...
+ def input_repr(self, ty_map: Callable[[str], str] | None = None) -> str:
...
+ def output_repr(self, ty_map: Callable[[str], str] | None = None) -> str:
...
def check_value(self, value: object) -> None: ...
def convert(self, value: object) -> CAny: ...
def to_json(self) -> dict[str, Any]: ...
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 960d5cd5..f14fb6b8 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -116,6 +116,8 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = {
"ObjectRValueRef": "Object",
}
+cdef str _TYPE_ATTR_FFI_CONVERT_TYPE_SCHEMA = "__ffi_convert_type_schema__"
+
# Sentinel for structural types (Optional, Union) that have no single
type_index
_ORIGIN_TYPE_INDEX_STRUCTURAL = -2
# Sentinel for unknown/unresolved origins
@@ -217,6 +219,19 @@ class TypeSchema:
def __repr__(self) -> str:
return self.repr(ty_map=None)
+ @staticmethod
+ def _from_maybe_json(raw: object) -> "TypeSchema":
+ """Construct a TypeSchema from a TypeSchema, JSON string, or JSON
dict."""
+ if isinstance(raw, TypeSchema):
+ return raw
+ if isinstance(raw, str):
+ return TypeSchema.from_json_str(raw)
+ if isinstance(raw, dict):
+ return TypeSchema.from_json_obj(raw)
+ raise TypeError(
+ f"expected TypeSchema, JSON string, or JSON dict, got
{type(raw).__name__}"
+ )
+
@staticmethod
def from_json_obj(obj: dict[str, Any]) -> "TypeSchema":
"""Construct a :class:`TypeSchema` from a parsed JSON object.
@@ -517,12 +532,40 @@ class TypeSchema:
assert s.repr() == "Array[int]"
"""
- if ty_map is None:
- origin = self.origin
- else:
- origin = ty_map(self.origin)
+ return self.output_repr(ty_map)
+
+ def input_repr(self, ty_map: "Optional[Callable[[str], str]]" = None) ->
str:
+ """Render the Python input annotation accepted by this schema."""
+ return self._repr_impl(ty_map, input_mode=True,
expanded_convert_types=frozenset())
+
+ def output_repr(self, ty_map: "Optional[Callable[[str], str]]" = None) ->
str:
+ """Render the precise Python output annotation produced by this
schema."""
+ return self._repr_impl(ty_map, input_mode=False,
expanded_convert_types=frozenset())
+
+ def _repr_impl(
+ self,
+ ty_map: "Optional[Callable[[str], str]]",
+ input_mode: bool,
+ expanded_convert_types: "frozenset[int]",
+ ) -> str:
+ if input_mode and self.origin_type_index >= kTVMFFIStaticObjectBegin:
+ if self.origin_type_index not in expanded_convert_types:
+ raw_convert_schema = _lookup_type_attr(
+ self.origin_type_index, _TYPE_ATTR_FFI_CONVERT_TYPE_SCHEMA
+ )
+ if raw_convert_schema is not None:
+ return
TypeSchema._from_maybe_json(raw_convert_schema)._repr_impl(
+ ty_map,
+ input_mode=True,
+ expanded_convert_types=expanded_convert_types |
{self.origin_type_index},
+ )
+
+ origin = self.origin if ty_map is None else ty_map(self.origin)
schema_args = self.args
- args = [i.repr(ty_map) for i in (() if schema_args is None else
schema_args)]
+ args = [
+ i._repr_impl(ty_map, input_mode, expanded_convert_types)
+ for i in (() if schema_args is None else schema_args)
+ ]
if origin == "Union":
return " | ".join(args)
elif origin == "Optional":
@@ -1129,8 +1172,8 @@ cdef _register_py_methods(int32_t type_index, list
py_methods, frozenset type_at
----------
type_index : int
The runtime type index of the type.
- py_methods : list[tuple[str, Any, bool]]
- Each entry is ``(name, value, is_static)``.
+ py_methods : list[tuple[str, Any, bool, str | None]]
+ Each entry is ``(name, value, is_static, metadata_json)``.
type_attr_names : frozenset[str]
Names to register as TypeAttrColumn instead of TypeMethod.
"""
@@ -1139,11 +1182,12 @@ cdef _register_py_methods(int32_t type_index, list
py_methods, frozenset type_at
cdef TVMFFIAny sentinel_any
cdef int c_api_ret_code
cdef ByteArrayArg name_arg
+ cdef ByteArrayArg metadata_arg
sentinel_any.type_index = kTVMFFINone
sentinel_any.v_int64 = 0
- for name, func, is_static in py_methods:
+ for name, func, is_static, metadata_json in py_methods:
func_any.type_index = kTVMFFINone
func_any.v_int64 = 0
try:
@@ -1169,8 +1213,13 @@ cdef _register_py_methods(int32_t type_index, list
py_methods, frozenset type_at
method_info.doc.size = 0
method_info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod if
is_static else 0
method_info.method = func_any
- method_info.metadata.data = NULL
- method_info.metadata.size = 0
+ if metadata_json is not None:
+ metadata_bytes = c_str(metadata_json)
+ metadata_arg = ByteArrayArg(metadata_bytes)
+ method_info.metadata = metadata_arg.cdata
+ else:
+ method_info.metadata.data = NULL
+ method_info.metadata.size = 0
CHECK_CALL(TVMFFITypeRegisterMethod(type_index, &method_info))
finally:
if func_any.type_index >= kTVMFFIStaticObjectBegin and
func_any.v_obj != NULL:
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index 501c56b0..2799669b 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -24,7 +24,7 @@ from typing import Any, TypeVar
from typing_extensions import dataclass_transform
-from .field import Field
+from .field import Field, _field_converter, field
_T = TypeVar("_T", bound=type)
@@ -60,7 +60,12 @@ def _attach_field_objects(cls: type, type_info: Any) -> None:
tf.dataclass_field = f
-@dataclass_transform(eq_default=False, order_default=False)
+@dataclass_transform(
+ eq_default=False,
+ order_default=False,
+ field_specifiers=(field, Field),
+ converter=_field_converter,
+)
def c_class(
type_key: str,
*,
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
index 08c2ea3f..ab7dd66f 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,6 +37,11 @@ else:
"""Sentinel type: annotations after ``_: KW_ONLY`` are keyword-only."""
+def _field_converter(value: Any) -> Any:
+ """Static-analysis marker for fields whose values are converted by FFI."""
+ return value
+
+
class Field:
"""Descriptor for a single field in a Python-defined TVM-FFI type.
@@ -102,12 +107,16 @@ class Field:
parameters that reference outer-scope vars.
doc : str | None
Optional docstring for the field.
+ converter : Callable[[Any], Any]
+ Static-analysis marker for field conversion. Runtime conversion is
+ still handled by the FFI type converter.
"""
__slots__ = (
"_ty_schema",
"compare",
+ "converter",
"default",
"default_factory",
"doc",
@@ -130,6 +139,7 @@ class Field:
repr: bool
hash: bool | None
compare: bool
+ converter: Callable[[Any], Any]
kw_only: bool | None
structural_eq: str | None
doc: str | None
@@ -158,6 +168,7 @@ class Field:
kw_only: bool | None = False,
structural_eq: str | None = None,
doc: str | None = None,
+ converter: Callable[[Any], Any] = _field_converter,
) -> None:
# MISSING means "parameter not provided".
# An explicit None from the user fails the callable() check,
@@ -185,12 +196,13 @@ class Field:
self.repr = repr
self.hash = hash
self.compare = compare
+ self.converter = converter
self.kw_only = kw_only
self.structural_eq = structural_eq
self.doc = doc
-def field(
+def field( # noqa: PLR0913
*,
default: object = MISSING,
default_factory: Callable[[], object] | None = MISSING, # type:
ignore[assignment]
@@ -202,6 +214,7 @@ def field(
kw_only: bool | None = None,
structural_eq: str | None = None,
doc: str | None = None,
+ converter: Callable[[Any], Any] = _field_converter,
) -> Any:
"""Customize a field in a ``@py_class``-decorated class.
@@ -248,6 +261,9 @@ def field(
binding.
doc
Optional docstring for the field.
+ converter
+ Static-analysis marker for field conversion. Runtime conversion is
+ still handled by the FFI type converter.
Returns
-------
@@ -282,4 +298,5 @@ def field(
kw_only=kw_only,
structural_eq=structural_eq,
doc=doc,
+ converter=converter,
)
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index bb2c8a1f..28e8f957 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -18,6 +18,8 @@
from __future__ import annotations
+import inspect
+import json
import sys
import typing
from collections.abc import Callable
@@ -31,7 +33,7 @@ from .. import core
from .._dunder import _install_dataclass_dunders
from ..core import MISSING, TypeSchema
from ..registry import _add_class_attrs
-from .field import KW_ONLY, Field, field
+from .field import KW_ONLY, Field, _field_converter, field
_T = TypeVar("_T", bound=type)
@@ -79,6 +81,11 @@ class _PendingClass:
#: :func:`_flush_pending`.
_PENDING_CLASSES: list[_PendingClass] = []
+
+class _DeferredAnnotation(Exception):
+ """Raised when method annotations must wait for a later py_class
registration."""
+
+
#: Per-module mapping of ``class.__name__ → class`` for every
#: ``@py_class``-decorated type. Used as *localns* when resolving
#: annotations so that mutual references between classes in the same
@@ -328,7 +335,52 @@ def _validate_method_name(cls: type, name: str) -> None:
)
-def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
+def _method_type_schema_json(
+ cls: type,
+ func: Any,
+ is_static: bool,
+ globalns: dict[str, Any],
+) -> str:
+ """Build reflection metadata for a Python-defined FFI TypeMethod."""
+ kwargs: dict[str, Any] = {"globalns": globalns, "localns":
_build_localns(cls)}
+ if sys.version_info >= (3, 11):
+ kwargs["include_extras"] = True
+ try:
+ hints = typing.get_type_hints(func, **kwargs)
+ except (NameError, AttributeError):
+ kwargs["localns"] = _build_localns(cls, cross_module=True)
+ try:
+ hints = typing.get_type_hints(func, **kwargs)
+ except (NameError, AttributeError) as err:
+ raise _DeferredAnnotation from err
+
+ sig = inspect.signature(func)
+ params = list(sig.parameters.values())
+ if any(
+ param.kind in (inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD)
+ for param in params
+ ):
+ return json.dumps({"type_schema": TypeSchema("Callable").to_json()})
+
+ arg_schemas: list[TypeSchema] = []
+ if not is_static:
+ arg_schemas.append(TypeSchema.from_annotation(cls))
+ params = params[1:]
+
+ for param in params:
+ if param.kind is inspect.Parameter.KEYWORD_ONLY:
+ return json.dumps({"type_schema":
TypeSchema("Callable").to_json()})
+ annotation = hints.get(param.name, Any)
+ arg_schemas.append(TypeSchema.from_annotation(annotation))
+
+ ret_annotation = hints.get("return", Any)
+ ret_schema = TypeSchema.from_annotation(ret_annotation)
+ return json.dumps({"type_schema": TypeSchema("Callable", (ret_schema,
*arg_schemas)).to_json()})
+
+
+def _collect_py_methods(
+ cls: type, globalns: dict[str, Any] | None = None
+) -> list[tuple[Any, ...]] | None:
"""Extract FFI-registered entries from a ``@py_class`` body.
Two sources are collected:
@@ -349,10 +401,14 @@ def _collect_py_methods(cls: type) -> list[tuple[str,
Any, bool]] | None:
and Python protocol dunders cannot be ``@method``-decorated; those
are reserved by the TypeAttrColumn and Python semantics respectively.
- Returns the ``(name, value, is_static)`` list, or :data:`None` when
- no entries were found.
+ Returns the ``(name, value, is_static, metadata_json)`` list, or
+ :data:`None` when no entries were found.
"""
- methods: list[tuple[str, Any, bool]] = []
+ legacy_shape = globalns is None
+ if globalns is None:
+ globalns = vars(sys.modules[cls.__module__])
+
+ methods: list[tuple[Any, ...]] = []
for name, value in cls.__dict__.items():
marked = _is_method_marked(value)
if name not in _FFI_RECOGNIZED_METHODS and not marked:
@@ -374,7 +430,13 @@ def _collect_py_methods(cls: type) -> list[tuple[str, Any,
bool]] | None:
)
is_static = isinstance(value, staticmethod)
func = value.__func__ if is_static else value
- methods.append((name, func, is_static))
+ metadata_json = None
+ if marked:
+ metadata_json = _method_type_schema_json(cls, func, is_static,
globalns)
+ if legacy_shape:
+ methods.append((name, func, is_static))
+ else:
+ methods.append((name, func, is_static, metadata_json))
return methods if methods else None
@@ -447,7 +509,10 @@ def _register_fields_into_type(
assert f.name is not None
fields_map[f.name] = f
own_fields = list(fields_map.values())
- py_methods = _collect_py_methods(cls)
+ try:
+ py_methods = _collect_py_methods(cls, globalns)
+ except _DeferredAnnotation:
+ return False
# Register fields and type-level structural eq/hash kind with the C layer.
structure_kind = _STRUCTURE_KIND_MAP.get(params.get("structural_eq"))
@@ -603,6 +668,7 @@ _FFI_TYPE_ATTR_NAMES: frozenset[str] = frozenset(
"__ffi_eq__",
"__ffi_compare__",
"__ffi_convert__",
+ "__ffi_convert_type_schema__",
"__any_hash__",
"__any_equal__",
"__s_equal__",
@@ -627,6 +693,7 @@ _FFI_RECOGNIZED_METHODS: frozenset[str] =
_FFI_TYPE_ATTR_NAMES
eq_default=False,
order_default=False,
field_specifiers=(field, Field),
+ converter=_field_converter,
)
def py_class( # noqa: PLR0913
cls_or_type_key: type | str | None = None,
diff --git a/python/tvm_ffi/stub/python_generator/codegen.py
b/python/tvm_ffi/stub/python_generator/codegen.py
index c3ffe76d..0f813464 100644
--- a/python/tvm_ffi/stub/python_generator/codegen.py
+++ b/python/tvm_ffi/stub/python_generator/codegen.py
@@ -30,6 +30,7 @@ from typing import Callable
from .. import consts as C
from ..file_utils import CodeBlock
from ..utils import FuncInfo, InitConfig, ObjectInfo, Options
+from . import consts as PC
from .utils import (
ImportItem,
render_func_signature,
@@ -103,6 +104,16 @@ def _type_suffix_and_record(
return _run
+def _make_input_ty_map(ty_map: dict[str, str]) -> dict[str, str]:
+ """Derive input-side defaults without overriding explicit ty-map
entries."""
+ input_ty_map = ty_map.copy()
+ for key, input_default in PC.TY_MAP_INPUT_DEFAULTS.items():
+ output_default = PC.TY_MAP_DEFAULTS.get(key)
+ if ty_map.get(key, output_default) == output_default:
+ input_ty_map[key] = input_default
+ return input_ty_map
+
+
def generate_python_global_funcs(
code: CodeBlock,
global_funcs: list[FuncInfo],
@@ -136,11 +147,17 @@ def generate_python_global_funcs(
)
func_names = {f.schema.name.rsplit(".", 1)[-1] for f in global_funcs}
fn_ty_map = _type_suffix_and_record(ty_map, imports, func_names=func_names)
+ input_fn_ty_map = _type_suffix_and_record(
+ _make_input_ty_map(ty_map), imports, func_names=func_names
+ )
results: list[str] = [
"# fmt: off",
f'_FFI_INIT_FUNC("{prefix}", __name__)',
"if TYPE_CHECKING:",
- *[render_func_signature(func, fn_ty_map, opt.indent) for func in
global_funcs],
+ *[
+ render_func_signature(func, fn_ty_map, opt.indent,
input_ty_map=input_fn_ty_map)
+ for func in global_funcs
+ ],
"# fmt: on",
]
indent = " " * code.indent
@@ -166,12 +183,17 @@ def generate_python_object(
info = obj_info
method_names = {m.schema.name.rsplit(".", 1)[-1] for m in info.methods}
fn_ty_map = _type_suffix_and_record(ty_map, imports,
func_names=method_names)
- init_lines = render_object_init(info, fn_ty_map, opt.indent)
- ffi_init_lines = render_object_ffi_init(info, fn_ty_map, opt.indent)
+ input_fn_ty_map = _type_suffix_and_record(
+ _make_input_ty_map(ty_map), imports, func_names=method_names
+ )
+ init_lines = render_object_init(info, fn_ty_map, opt.indent,
input_ty_map=input_fn_ty_map)
+ ffi_init_lines = render_object_ffi_init(
+ info, fn_ty_map, opt.indent, input_ty_map=input_fn_ty_map
+ )
type_checking_lines = [
*init_lines,
*ffi_init_lines,
- *render_object_methods(info, fn_ty_map, opt.indent),
+ *render_object_methods(info, fn_ty_map, opt.indent,
input_ty_map=input_fn_ty_map),
]
if type_checking_lines:
imports.append(
diff --git a/python/tvm_ffi/stub/python_generator/consts.py
b/python/tvm_ffi/stub/python_generator/consts.py
index a788bc68..5f931841 100644
--- a/python/tvm_ffi/stub/python_generator/consts.py
+++ b/python/tvm_ffi/stub/python_generator/consts.py
@@ -38,6 +38,13 @@ TY_MAP_DEFAULTS = {
"Device": "ffi.Device",
}
+TY_MAP_INPUT_DEFAULTS = {
+ "Array": "collections.abc.Sequence",
+ "List": "collections.abc.Sequence",
+ "Map": "collections.abc.Mapping",
+ "Dict": "collections.abc.Mapping",
+}
+
# TODO(@junrushao): Make it configurable
#: Module-prefix rewrites applied when constructing a Python ``import`` path.
MOD_MAP = {
diff --git a/python/tvm_ffi/stub/python_generator/utils.py
b/python/tvm_ffi/stub/python_generator/utils.py
index 02c0f7b4..37eff8f1 100644
--- a/python/tvm_ffi/stub/python_generator/utils.py
+++ b/python/tvm_ffi/stub/python_generator/utils.py
@@ -102,8 +102,11 @@ def render_func_signature(
func: FuncInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> str:
"""Render a function signature string for ``func``."""
+ if input_ty_map is None:
+ input_ty_map = ty_map
func_name = func.schema.name.rsplit(".", 1)[-1]
buf = StringIO()
buf.write(" " * indent)
@@ -121,12 +124,12 @@ def render_func_signature(
buf.write("self, ")
else:
buf.write(f"_{i}: ")
- buf.write(arg.repr(ty_map))
+ buf.write(arg.input_repr(input_ty_map))
buf.write(", ")
if arg_args:
buf.write("/")
buf.write(") -> ")
- buf.write(arg_ret.repr(ty_map))
+ buf.write(arg_ret.output_repr(ty_map))
buf.write(": ...")
return buf.getvalue()
@@ -138,15 +141,18 @@ def render_object_fields(
) -> list[str]:
"""Render field definitions for ``info``."""
indent_str = " " * indent
- return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in
info.fields]
+ return [f"{indent_str}{field.name}: {field.output_repr(ty_map)}" for field
in info.fields]
def render_object_methods(
info: ObjectInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> list[str]:
"""Render method definitions for ``info``."""
+ if input_ty_map is None:
+ input_ty_map = ty_map
indent_str = " " * indent
ret = []
for method in info.methods:
@@ -154,11 +160,11 @@ def render_object_methods(
if func_name == "__ffi_init__":
# __ffi_init__ is installed as an instance method (self, *args,
**kwargs) -> None
# by _install_ffi_init_attr, regardless of the C++ static
registration.
- ret.append(_render_ffi_init_from_method(method, ty_map, indent))
+ ret.append(_render_ffi_init_from_method(method, ty_map, indent,
input_ty_map))
continue
if not method.is_member:
ret.append(f"{indent_str}@staticmethod")
- ret.append(render_func_signature(method, ty_map, indent))
+ ret.append(render_func_signature(method, ty_map, indent, input_ty_map))
return ret
@@ -166,8 +172,11 @@ def _render_ffi_init_from_method(
method: FuncInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> str:
"""Render ``__ffi_init__`` TypeMethod as an instance method returning
None."""
+ if input_ty_map is None:
+ input_ty_map = ty_map
indent_str = " " * indent
schema = method.schema
# Subclass __ffi_init__ signatures legitimately differ from the parent
@@ -179,7 +188,7 @@ def _render_ffi_init_from_method(
# schema.args[0] is return type, schema.args[1:] are param types.
parts: list[str] = []
for i, arg in enumerate(schema.args[1:]):
- parts.append(f"_{i}: {arg.repr(ty_map)}")
+ parts.append(f"_{i}: {arg.input_repr(input_ty_map)}")
if parts:
params = ", ".join(parts)
return f"{indent_str}def __ffi_init__(self, {params}, /) -> None:
...{ignore}"
@@ -190,6 +199,7 @@ def render_object_ffi_init(
info: ObjectInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> list[str]:
"""Render a ``__ffi_init__`` stub when it's not already in TypeMethod.
@@ -203,25 +213,29 @@ def render_object_ffi_init(
# If __ffi_init__ is already in methods (from TypeMethod), methods render
it.
if any(m.schema.name.rsplit(".", 1)[-1] == "__ffi_init__" for m in
info.methods):
return []
- return _render_ffi_init_from_fields(info, ty_map, indent)
+ return _render_ffi_init_from_fields(info, ty_map, indent, input_ty_map)
def render_object_init(
info: ObjectInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> list[str]:
"""Render an ``__init__`` stub from init-eligible field metadata."""
if not info.has_init:
return []
- return _render_init_from_fields(info, ty_map, indent)
+ return _render_init_from_fields(info, ty_map, indent, input_ty_map)
def _format_field_params(
info: ObjectInfo,
ty_map: Callable[[str], str],
+ input_ty_map: Callable[[str], str] | None = None,
) -> str:
"""Format init-eligible fields as a parameter string with defaults and
kw_only."""
+ if input_ty_map is None:
+ input_ty_map = ty_map
positional = [f for f in info.init_fields if not f.kw_only]
kw_only = [f for f in info.init_fields if f.kw_only]
@@ -232,15 +246,15 @@ def _format_field_params(
parts: list[str] = []
for f in pos_required:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)}")
for f in pos_default:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+ parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)} = ...")
if kw_required or kw_default:
parts.append("*")
for f in kw_required:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)}")
for f in kw_default:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+ parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)} =
...")
return ", ".join(parts)
@@ -249,10 +263,11 @@ def _render_init_from_fields(
info: ObjectInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> list[str]:
"""Render ``__init__`` from init-eligible field metadata (auto-generated
init)."""
indent_str = " " * indent
- params = _format_field_params(info, ty_map)
+ params = _format_field_params(info, ty_map, input_ty_map)
if params:
return [f"{indent_str}def __init__(self, {params}) -> None: ..."]
return [f"{indent_str}def __init__(self) -> None: ..."]
@@ -262,13 +277,14 @@ def _render_ffi_init_from_fields(
info: ObjectInfo,
ty_map: Callable[[str], str],
indent: int,
+ input_ty_map: Callable[[str], str] | None = None,
) -> list[str]:
"""Render ``__ffi_init__`` stub from field metadata for auto-generated
init."""
indent_str = " " * indent
# Subclass __ffi_init__ signatures legitimately differ from the parent
# (different fields -> different constructor params), so suppress LSP.
ignore = " # ty: ignore[invalid-method-override]"
- params = _format_field_params(info, ty_map)
+ params = _format_field_params(info, ty_map, input_ty_map)
if params:
return [f"{indent_str}def __ffi_init__(self, {params}) -> None:
...{ignore}"]
return [f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"]
diff --git a/tests/python/test_dataclass_c_class.py
b/tests/python/test_dataclass_c_class.py
index 4d86073f..a1e3f23a 100644
--- a/tests/python/test_dataclass_c_class.py
+++ b/tests/python/test_dataclass_c_class.py
@@ -24,8 +24,8 @@ import warnings
import pytest
import tvm_ffi.testing
from tvm_ffi.core import MISSING, TypeInfo
-from tvm_ffi.dataclasses import Field
-from tvm_ffi.dataclasses.c_class import _attach_field_objects
+from tvm_ffi.dataclasses import Field, field
+from tvm_ffi.dataclasses.c_class import _attach_field_objects, c_class
from tvm_ffi.registry import _warn_missing_field_annotations
from tvm_ffi.testing import (
TestCompare,
@@ -42,6 +42,11 @@ from tvm_ffi.testing import (
# ---------------------------------------------------------------------------
+def test_c_class_dataclass_transform_has_converter() -> None:
+ metadata = c_class.__dataclass_transform__ # ty:
ignore[unresolved-attribute]
+ assert metadata["kwargs"]["converter"] is field().converter
+
+
def test_c_class_custom_init() -> None:
"""c_class preserves user-defined __init__."""
obj = _TestCxxClassBase(v_i64=10, v_i32=20)
diff --git a/tests/python/test_dataclass_py_class.py
b/tests/python/test_dataclass_py_class.py
index 5a436144..49888aa5 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -53,6 +53,20 @@ def _get_type_info(cls: type) -> TypeInfo:
return ret
+def test_py_class_dataclass_transform_has_converter() -> None:
+ metadata = py_class.__dataclass_transform__ # ty:
ignore[unresolved-attribute]
+ assert metadata["kwargs"]["converter"] is field().converter
+
+
+def test_field_accepts_converter_metadata() -> None:
+ def converter(value: Any) -> Any:
+ return value
+
+ f = field(converter=converter)
+ assert isinstance(f, Field)
+ assert f.converter is converter
+
+
# ---------------------------------------------------------------------------
# Low-level helpers for _make_type-based tests
# ---------------------------------------------------------------------------
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 36291188..19c2517d 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -448,23 +448,25 @@ def test_convert_func_with_torch_tensor_cls() -> None:
the outer caller's conversion path, so we verify shape survives the
round-trip rather than isinstance on the return.
"""
+ assert torch is not None
+ torch_mod = torch
calls = 0
def callback(a: Any, b: Any, c: Any) -> Any:
nonlocal calls
calls += 1
- assert isinstance(a, torch.Tensor)
- assert isinstance(b, torch.Tensor)
- assert isinstance(c, torch.Tensor)
+ assert isinstance(a, torch_mod.Tensor)
+ assert isinstance(b, torch_mod.Tensor)
+ assert isinstance(c, torch_mod.Tensor)
assert list(a.shape) == [2]
assert list(b.shape) == [3]
assert list(c.shape) == [4]
return b
- f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
- a = torch.zeros(2)
- b = torch.ones(3)
- c = torch.full((4,), 2.0)
+ f = tvm_ffi.convert_func(callback, tensor_cls=torch_mod.Tensor)
+ a = torch_mod.zeros(2)
+ b = torch_mod.ones(3)
+ c = torch_mod.full((4,), 2.0)
out = f(a, b, c)
assert calls == 1
assert tuple(out.shape) == (3,)
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
index 5c00a04d..a82c0760 100644
--- a/tests/python/test_stubgen.py
+++ b/tests/python/test_stubgen.py
@@ -16,11 +16,15 @@
# under the License.
from __future__ import annotations
+import itertools
+import typing
from pathlib import Path
import pytest
import tvm_ffi.stub.cli as stub_cli
+from tvm_ffi import Object, method
from tvm_ffi.core import TypeSchema
+from tvm_ffi.dataclasses import py_class
from tvm_ffi.stub import consts as C
from tvm_ffi.stub.cli import _stage_2, _stage_3
from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
@@ -35,23 +39,32 @@ from tvm_ffi.stub.python_generator.codegen import (
generate_python_init,
generate_python_object,
render_func_signature,
+ render_object_ffi_init,
render_object_fields,
+ render_object_init,
render_object_methods,
)
from tvm_ffi.stub.python_generator.utils import ImportItem
from tvm_ffi.stub.utils import (
FuncInfo,
InitConfig,
+ InitFieldInfo,
NamedTypeSchema,
ObjectInfo,
Options,
)
+_counter = itertools.count()
+
def _identity_ty_map(name: str) -> str:
return name
+def _unique_type_key(base: str) -> str:
+ return f"testing.stubgen.{base}_{next(_counter)}"
+
+
def _default_ty_map() -> dict[str, str]:
return PC.TY_MAP_DEFAULTS.copy()
@@ -60,6 +73,10 @@ def _type_suffix(name: str) -> str:
return PC.TY_MAP_DEFAULTS.get(name, name).rsplit(".", 1)[-1]
+def _input_type_suffix(name: str) -> str:
+ return PC.TY_MAP_INPUT_DEFAULTS.get(name, PC.TY_MAP_DEFAULTS.get(name,
name)).rsplit(".", 1)[-1]
+
+
def test_codeblock_from_begin_line_variants() -> None:
cases = [
(f"{C.PYTHON_SYNTAX.begin} global/demo", "global", ("demo", "")),
@@ -301,6 +318,126 @@ def test_objectinfo_gen_fields_container_types() -> None:
]
+def test_funcinfo_gen_uses_input_annotations_for_parameters() -> None:
+ info = FuncInfo(
+ schema=NamedTypeSchema(
+ "demo.echo_list",
+ TypeSchema(
+ "Callable",
+ (
+ TypeSchema("List", (TypeSchema("int"),)),
+ TypeSchema("List", (TypeSchema("int"),)),
+ ),
+ ),
+ ),
+ is_member=False,
+ )
+
+ assert (
+ render_func_signature(info, _type_suffix, indent=0,
input_ty_map=_input_type_suffix)
+ == "def echo_list(_0: Sequence[int], /) -> MutableSequence[int]: ..."
+ )
+
+
+def test_generate_global_funcs_populates_input_defaults_for_partial_ty_map()
-> None:
+ code = CodeBlock(
+ kind="global",
+ param=("demo", "mockpkg"),
+ lineno_start=1,
+ lineno_end=2,
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
C.PYTHON_SYNTAX.end],
+ )
+ funcs = [
+ FuncInfo(
+ schema=NamedTypeSchema(
+ "demo.echo_list",
+ TypeSchema(
+ "Callable",
+ (
+ TypeSchema("List", (TypeSchema("int"),)),
+ TypeSchema("List", (TypeSchema("int"),)),
+ ),
+ ),
+ ),
+ is_member=False,
+ )
+ ]
+ imports: list[ImportItem] = []
+
+ generate_python_global_funcs(
+ code, funcs, {"List": "collections.abc.MutableSequence"}, imports,
Options()
+ )
+
+ assert code.lines == [
+ f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
+ "# fmt: off",
+ '_FFI_INIT_FUNC("demo", __name__)',
+ "if TYPE_CHECKING:",
+ " def echo_list(_0: Sequence[int], /) -> MutableSequence[int]: ...",
+ "# fmt: on",
+ C.PYTHON_SYNTAX.end,
+ ]
+
+
+def test_objectinfo_gen_init_uses_input_annotations() -> None:
+ info = ObjectInfo(
+ fields=[NamedTypeSchema("items", TypeSchema("List",
(TypeSchema("int"),)))],
+ methods=[],
+ init_fields=[
+ InitFieldInfo(
+ name="items",
+ schema=NamedTypeSchema("items", TypeSchema("List",
(TypeSchema("int"),))),
+ kw_only=False,
+ has_default=False,
+ )
+ ],
+ has_init=True,
+ )
+
+ assert render_object_fields(info, _type_suffix, indent=0) == ["items:
MutableSequence[int]"]
+ assert render_object_init(info, _type_suffix, indent=0,
input_ty_map=_input_type_suffix) == [
+ "def __init__(self, items: Sequence[int]) -> None: ..."
+ ]
+ assert render_object_ffi_init(
+ info, _type_suffix, indent=0, input_ty_map=_input_type_suffix
+ ) == [
+ "def __ffi_init__(self, items: Sequence[int]) -> None: ... # ty: "
+ "ignore[invalid-method-override]"
+ ]
+
+
+def test_py_class_method_metadata_renders_stub_signature() -> None:
+ @py_class(_unique_type_key("MethodMetadata"))
+ class MethodMetadata(Object):
+ value: int
+
+ @method
+ def describe(self, values: typing.List[int], prefix: str) -> str: #
noqa: UP006
+ return f"{prefix}:{self.value}:{len(values)}"
+
+ @method
+ @staticmethod
+ def normalize(values: typing.List[int]) -> typing.List[int]: # noqa:
UP006
+ return values
+
+ info = ObjectInfo.from_type_info(MethodMetadata.__tvm_ffi_type_info__) #
ty: ignore[unresolved-attribute]
+ methods = {method.schema.name: method for method in info.methods}
+ describe_schema = methods["describe"].schema
+
+ assert describe_schema.origin == "Callable"
+ assert [arg.origin for arg in describe_schema.args] == [
+ "str",
+ MethodMetadata.__tvm_ffi_type_info__.type_key, # ty:
ignore[unresolved-attribute]
+ "List",
+ "str",
+ ]
+ assert render_object_methods(info, _type_suffix, indent=0,
input_ty_map=_input_type_suffix) == [
+ "def describe(self, _1: Sequence[int], _2: str, /) -> str: ...",
+ "@staticmethod",
+ "def normalize(_0: Sequence[int], /) -> MutableSequence[int]: ...",
+ ]
+
+
def test_generate_global_funcs_updates_block() -> None:
code = CodeBlock(
kind="global",
diff --git a/tests/python/test_type_converter.py
b/tests/python/test_type_converter.py
index 82a0d7fb..a639e5d7 100644
--- a/tests/python/test_type_converter.py
+++ b/tests/python/test_type_converter.py
@@ -31,13 +31,14 @@ import pytest
import tvm_ffi
from tvm_ffi.core import (
CAny,
+ Object,
ObjectConvertible,
TypeSchema,
_lookup_type_attr,
_object_type_key_to_index,
_to_py_class_value,
)
-from tvm_ffi.dataclasses import IntEnum, StrEnum, entry
+from tvm_ffi.dataclasses import IntEnum, StrEnum, entry, py_class
from tvm_ffi.testing import (
TestIntPair,
TestObjectBase,
@@ -3443,6 +3444,54 @@ class TestSTLOriginParsing:
assert s.origin == "Object"
+def _output_ty_map(name: str) -> str:
+ return {
+ "Array": "Sequence",
+ "List": "MutableSequence",
+ "Map": "Mapping",
+ "Dict": "MutableMapping",
+ }.get(name, name)
+
+
+def _input_ty_map(name: str) -> str:
+ return {
+ "Array": "Sequence",
+ "List": "Sequence",
+ "Map": "Mapping",
+ "Dict": "Mapping",
+ }.get(name, name)
+
+
+class TestTypeSchemaAnnotationRendering:
+ """Input annotations are widened without changing output annotations."""
+
+ def test_container_input_output_repr(self) -> None:
+ """List and Dict render differently for input and output positions."""
+ schema = S("Dict", S("str"), S("List", S("int")))
+
+ assert schema.output_repr(_output_ty_map) == "MutableMapping[str,
MutableSequence[int]]"
+ assert schema.input_repr(_input_ty_map) == "Mapping[str,
Sequence[int]]"
+ assert schema.repr(_output_ty_map) ==
schema.output_repr(_output_ty_map)
+
+ def test_object_convert_type_schema_attr_widens_input_only(self) -> None:
+ """__ffi_convert_type_schema__ affects input annotations only."""
+ type_key = _unique_type_key("ConvertTypeSchema")
+ convert_schema = (
+
f'{{"type":"Union","args":[{{"type":"int"}},{{"type":"str"}},{{"type":"{type_key}"}}]}}'
+ )
+
+ @py_class(type_key)
+ class ExprLike(Object):
+ __ffi_convert_type_schema__ = convert_schema
+
+ def ty_map(name: str) -> str:
+ return "ExprLike" if name == type_key else name
+
+ schema = TypeSchema.from_annotation(ExprLike)
+ assert schema.output_repr(ty_map) == "ExprLike"
+ assert schema.input_repr(ty_map) == "int | str | ExprLike"
+
+
# ---------------------------------------------------------------------------
# Category 64: Zero-copy container conversion
# ---------------------------------------------------------------------------