This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new bb2e4586 [PY] Distinguish FFI input and output annotations (#621)
bb2e4586 is described below

commit bb2e4586fff47ed75331b44639b1fb0ffb6f1e50
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Jun 23 10:23:26 2026 -0400

    [PY] Distinguish FFI input and output annotations (#621)
    
    ## Summary
    - Add separate input/output rendering APIs for `TypeSchema` so stubgen
    can widen callable and constructor inputs without weakening returned or
    field types.
    - Add a `__ffi_convert_type_schema__` TypeAttr hook for object
    conversions and wire `py_class`/`c_class` dataclass transform converter
    metadata.
    - Generate `type_schema` metadata for Python-defined `@method`
    TypeMethods so reflection-driven stubgen can render their signatures.
    - Cover container asymmetry, object convert schema recursion, stubgen
    signatures, partial type maps, `@method` metadata, and dataclass
    metadata with regression tests.
    - Update the PR CI workflow to avoid Apache-disallowed third-party setup
    actions and run pre-commit without syncing the project during lint.
    
    ## Validation
    - pre-commit hooks during commits, including ruff, ty, clang-format,
    cython-lint, CMake lint/format, and workflow YAML checks
    - `python -m pytest tests/python/test_stubgen.py::test_py_class_method_metadata_renders_stub_signature tests/python/test_typed_method.py -q`
    - `python -m pytest tests/python/test_stubgen.py -q`
    - `CUDA_VISIBLE_DEVICES="" python -m pytest tests/python -q` (`2322 passed, 18 skipped, 2 xfailed`)
    - Manual output dumped under
    `tmp/2026-06-14-update-ffi/manual_check.out`
    - GitHub CI is green for lint, docs, Ubuntu, Ubuntu ARM, macOS, and
    Windows
---
 CMakeLists.txt                                  |   3 +
 cmake/Utils/AddGoogleTest.cmake                 |   2 +-
 include/tvm/ffi/reflection/accessor.h           |   8 ++
 python/tvm_ffi/core.pyi                         |   2 +
 python/tvm_ffi/cython/type_info.pxi             |  69 ++++++++++--
 python/tvm_ffi/dataclasses/c_class.py           |   9 +-
 python/tvm_ffi/dataclasses/field.py             |  19 +++-
 python/tvm_ffi/dataclasses/py_class.py          |  81 ++++++++++++--
 python/tvm_ffi/stub/python_generator/codegen.py |  30 +++++-
 python/tvm_ffi/stub/python_generator/consts.py  |   7 ++
 python/tvm_ffi/stub/python_generator/utils.py   |  44 +++++---
 tests/python/test_dataclass_c_class.py          |   9 +-
 tests/python/test_dataclass_py_class.py         |  14 +++
 tests/python/test_function.py                   |  16 +--
 tests/python/test_stubgen.py                    | 137 ++++++++++++++++++++++++
 tests/python/test_type_converter.py             |  51 ++++++++-
 16 files changed, 452 insertions(+), 49 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d306bb2a..d0cb8f08 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -247,10 +247,13 @@ if (TVM_FFI_BUILD_PYTHON_MODULE)
   set(_cython_sources
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/core.pyx
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/base.pxi
+      ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/type_info.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/device.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/dtype.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/error.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/function.pxi
+      ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/pycallback.pxi
+      ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/pyclass_type_converter.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/tensor.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/object.pxi
       ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm_ffi/cython/string.pxi
diff --git a/cmake/Utils/AddGoogleTest.cmake b/cmake/Utils/AddGoogleTest.cmake
index cd8d0931..8b4800da 100644
--- a/cmake/Utils/AddGoogleTest.cmake
+++ b/cmake/Utils/AddGoogleTest.cmake
@@ -91,7 +91,7 @@ macro (TVM_FFI_ADD_GTEST target_name)
   target_link_libraries(${target_name} PRIVATE gtest_main)
   gtest_discover_tests(
     ${target_name}
-    WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} TEST_DISCOVERY_TIMEOUT 600 DISCOVERY_MODE PRE_TEST
+    WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} DISCOVERY_TIMEOUT 600 DISCOVERY_MODE PRE_TEST
     PROPERTIES VS_DEBUGGER_WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}"
   )
   set_target_properties(${target_name} PROPERTIES FOLDER tests)
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index 80de4931..d0aa734c 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -350,6 +350,14 @@ inline constexpr const char* kInit = "__ffi_init__";
  * Signature: ``(AnyView src) -> TSelf``, where ``TSelf`` is a subclass of ObjectRef.
  */
 inline constexpr const char* kConvert = "__ffi_convert__";
+/*!
+ * \brief Type schema accepted by ``kConvert`` before it returns ``TSelf``.
+ *
+ * Stored as a JSON type schema string or schema-like object.  Python stub
+ * generation uses this attribute to render widened input annotations while
+ * keeping output annotations precise.
+ */
+inline constexpr const char* kConvertTypeSchema = "__ffi_convert_type_schema__";
 /*!
  * \brief Shallow-copy factory.
  *
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index c6512365..59e68ddb 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -274,6 +274,8 @@ class TypeSchema:
     @staticmethod
     def from_annotation(annotation: object) -> TypeSchema: ...
     def repr(self, ty_map: Callable[[str], str] | None = None) -> str: ...
+    def input_repr(self, ty_map: Callable[[str], str] | None = None) -> str: ...
+    def output_repr(self, ty_map: Callable[[str], str] | None = None) -> str: ...
     def check_value(self, value: object) -> None: ...
     def convert(self, value: object) -> CAny: ...
     def to_json(self) -> dict[str, Any]: ...
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index 960d5cd5..f14fb6b8 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -116,6 +116,8 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = {
     "ObjectRValueRef": "Object",
 }
 
+cdef str _TYPE_ATTR_FFI_CONVERT_TYPE_SCHEMA = "__ffi_convert_type_schema__"
+
 # Sentinel for structural types (Optional, Union) that have no single type_index
 _ORIGIN_TYPE_INDEX_STRUCTURAL = -2
 # Sentinel for unknown/unresolved origins
@@ -217,6 +219,19 @@ class TypeSchema:
     def __repr__(self) -> str:
         return self.repr(ty_map=None)
 
+    @staticmethod
+    def _from_maybe_json(raw: object) -> "TypeSchema":
+        """Construct a TypeSchema from a TypeSchema, JSON string, or JSON 
dict."""
+        if isinstance(raw, TypeSchema):
+            return raw
+        if isinstance(raw, str):
+            return TypeSchema.from_json_str(raw)
+        if isinstance(raw, dict):
+            return TypeSchema.from_json_obj(raw)
+        raise TypeError(
+            f"expected TypeSchema, JSON string, or JSON dict, got 
{type(raw).__name__}"
+        )
+
     @staticmethod
     def from_json_obj(obj: dict[str, Any]) -> "TypeSchema":
         """Construct a :class:`TypeSchema` from a parsed JSON object.
@@ -517,12 +532,40 @@ class TypeSchema:
             assert s.repr() == "Array[int]"
 
         """
-        if ty_map is None:
-            origin = self.origin
-        else:
-            origin = ty_map(self.origin)
+        return self.output_repr(ty_map)
+
+    def input_repr(self, ty_map: "Optional[Callable[[str], str]]" = None) -> 
str:
+        """Render the Python input annotation accepted by this schema."""
+        return self._repr_impl(ty_map, input_mode=True, 
expanded_convert_types=frozenset())
+
+    def output_repr(self, ty_map: "Optional[Callable[[str], str]]" = None) -> 
str:
+        """Render the precise Python output annotation produced by this 
schema."""
+        return self._repr_impl(ty_map, input_mode=False, 
expanded_convert_types=frozenset())
+
+    def _repr_impl(
+        self,
+        ty_map: "Optional[Callable[[str], str]]",
+        input_mode: bool,
+        expanded_convert_types: "frozenset[int]",
+    ) -> str:
+        if input_mode and self.origin_type_index >= kTVMFFIStaticObjectBegin:
+            if self.origin_type_index not in expanded_convert_types:
+                raw_convert_schema = _lookup_type_attr(
+                    self.origin_type_index, _TYPE_ATTR_FFI_CONVERT_TYPE_SCHEMA
+                )
+                if raw_convert_schema is not None:
+                    return 
TypeSchema._from_maybe_json(raw_convert_schema)._repr_impl(
+                        ty_map,
+                        input_mode=True,
+                        expanded_convert_types=expanded_convert_types | 
{self.origin_type_index},
+                    )
+
+        origin = self.origin if ty_map is None else ty_map(self.origin)
         schema_args = self.args
-        args = [i.repr(ty_map) for i in (() if schema_args is None else 
schema_args)]
+        args = [
+            i._repr_impl(ty_map, input_mode, expanded_convert_types)
+            for i in (() if schema_args is None else schema_args)
+        ]
         if origin == "Union":
             return " | ".join(args)
         elif origin == "Optional":
@@ -1129,8 +1172,8 @@ cdef _register_py_methods(int32_t type_index, list 
py_methods, frozenset type_at
     ----------
     type_index : int
         The runtime type index of the type.
-    py_methods : list[tuple[str, Any, bool]]
-        Each entry is ``(name, value, is_static)``.
+    py_methods : list[tuple[str, Any, bool, str | None]]
+        Each entry is ``(name, value, is_static, metadata_json)``.
     type_attr_names : frozenset[str]
         Names to register as TypeAttrColumn instead of TypeMethod.
     """
@@ -1139,11 +1182,12 @@ cdef _register_py_methods(int32_t type_index, list 
py_methods, frozenset type_at
     cdef TVMFFIAny sentinel_any
     cdef int c_api_ret_code
     cdef ByteArrayArg name_arg
+    cdef ByteArrayArg metadata_arg
 
     sentinel_any.type_index = kTVMFFINone
     sentinel_any.v_int64 = 0
 
-    for name, func, is_static in py_methods:
+    for name, func, is_static, metadata_json in py_methods:
         func_any.type_index = kTVMFFINone
         func_any.v_int64 = 0
         try:
@@ -1169,8 +1213,13 @@ cdef _register_py_methods(int32_t type_index, list 
py_methods, frozenset type_at
                 method_info.doc.size = 0
                 method_info.flags = kTVMFFIFieldFlagBitMaskIsStaticMethod if 
is_static else 0
                 method_info.method = func_any
-                method_info.metadata.data = NULL
-                method_info.metadata.size = 0
+                if metadata_json is not None:
+                    metadata_bytes = c_str(metadata_json)
+                    metadata_arg = ByteArrayArg(metadata_bytes)
+                    method_info.metadata = metadata_arg.cdata
+                else:
+                    method_info.metadata.data = NULL
+                    method_info.metadata.size = 0
                 CHECK_CALL(TVMFFITypeRegisterMethod(type_index, &method_info))
         finally:
             if func_any.type_index >= kTVMFFIStaticObjectBegin and 
func_any.v_obj != NULL:
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index 501c56b0..2799669b 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -24,7 +24,7 @@ from typing import Any, TypeVar
 
 from typing_extensions import dataclass_transform
 
-from .field import Field
+from .field import Field, _field_converter, field
 
 _T = TypeVar("_T", bound=type)
 
@@ -60,7 +60,12 @@ def _attach_field_objects(cls: type, type_info: Any) -> None:
         tf.dataclass_field = f
 
 
-@dataclass_transform(eq_default=False, order_default=False)
+@dataclass_transform(
+    eq_default=False,
+    order_default=False,
+    field_specifiers=(field, Field),
+    converter=_field_converter,
+)
 def c_class(
     type_key: str,
     *,
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index 08c2ea3f..ab7dd66f 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,6 +37,11 @@ else:
         """Sentinel type: annotations after ``_: KW_ONLY`` are keyword-only."""
 
 
+def _field_converter(value: Any) -> Any:
+    """Static-analysis marker for fields whose values are converted by FFI."""
+    return value
+
+
 class Field:
     """Descriptor for a single field in a Python-defined TVM-FFI type.
 
@@ -102,12 +107,16 @@ class Field:
           parameters that reference outer-scope vars.
     doc : str | None
         Optional docstring for the field.
+    converter : Callable[[Any], Any]
+        Static-analysis marker for field conversion. Runtime conversion is
+        still handled by the FFI type converter.
 
     """
 
     __slots__ = (
         "_ty_schema",
         "compare",
+        "converter",
         "default",
         "default_factory",
         "doc",
@@ -130,6 +139,7 @@ class Field:
     repr: bool
     hash: bool | None
     compare: bool
+    converter: Callable[[Any], Any]
     kw_only: bool | None
     structural_eq: str | None
     doc: str | None
@@ -158,6 +168,7 @@ class Field:
         kw_only: bool | None = False,
         structural_eq: str | None = None,
         doc: str | None = None,
+        converter: Callable[[Any], Any] = _field_converter,
     ) -> None:
         # MISSING means "parameter not provided".
         # An explicit None from the user fails the callable() check,
@@ -185,12 +196,13 @@ class Field:
         self.repr = repr
         self.hash = hash
         self.compare = compare
+        self.converter = converter
         self.kw_only = kw_only
         self.structural_eq = structural_eq
         self.doc = doc
 
 
-def field(
+def field(  # noqa: PLR0913
     *,
     default: object = MISSING,
     default_factory: Callable[[], object] | None = MISSING,  # type: 
ignore[assignment]
@@ -202,6 +214,7 @@ def field(
     kw_only: bool | None = None,
     structural_eq: str | None = None,
     doc: str | None = None,
+    converter: Callable[[Any], Any] = _field_converter,
 ) -> Any:
     """Customize a field in a ``@py_class``-decorated class.
 
@@ -248,6 +261,9 @@ def field(
         binding.
     doc
         Optional docstring for the field.
+    converter
+        Static-analysis marker for field conversion. Runtime conversion is
+        still handled by the FFI type converter.
 
     Returns
     -------
@@ -282,4 +298,5 @@ def field(
         kw_only=kw_only,
         structural_eq=structural_eq,
         doc=doc,
+        converter=converter,
     )
diff --git a/python/tvm_ffi/dataclasses/py_class.py b/python/tvm_ffi/dataclasses/py_class.py
index bb2c8a1f..28e8f957 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -18,6 +18,8 @@
 
 from __future__ import annotations
 
+import inspect
+import json
 import sys
 import typing
 from collections.abc import Callable
@@ -31,7 +33,7 @@ from .. import core
 from .._dunder import _install_dataclass_dunders
 from ..core import MISSING, TypeSchema
 from ..registry import _add_class_attrs
-from .field import KW_ONLY, Field, field
+from .field import KW_ONLY, Field, _field_converter, field
 
 _T = TypeVar("_T", bound=type)
 
@@ -79,6 +81,11 @@ class _PendingClass:
 #: :func:`_flush_pending`.
 _PENDING_CLASSES: list[_PendingClass] = []
 
+
+class _DeferredAnnotation(Exception):
+    """Raised when method annotations must wait for a later py_class 
registration."""
+
+
 #: Per-module mapping of ``class.__name__ → class`` for every
 #: ``@py_class``-decorated type.  Used as *localns* when resolving
 #: annotations so that mutual references between classes in the same
@@ -328,7 +335,52 @@ def _validate_method_name(cls: type, name: str) -> None:
         )
 
 
-def _collect_py_methods(cls: type) -> list[tuple[str, Any, bool]] | None:
+def _method_type_schema_json(
+    cls: type,
+    func: Any,
+    is_static: bool,
+    globalns: dict[str, Any],
+) -> str:
+    """Build reflection metadata for a Python-defined FFI TypeMethod."""
+    kwargs: dict[str, Any] = {"globalns": globalns, "localns": 
_build_localns(cls)}
+    if sys.version_info >= (3, 11):
+        kwargs["include_extras"] = True
+    try:
+        hints = typing.get_type_hints(func, **kwargs)
+    except (NameError, AttributeError):
+        kwargs["localns"] = _build_localns(cls, cross_module=True)
+        try:
+            hints = typing.get_type_hints(func, **kwargs)
+        except (NameError, AttributeError) as err:
+            raise _DeferredAnnotation from err
+
+    sig = inspect.signature(func)
+    params = list(sig.parameters.values())
+    if any(
+        param.kind in (inspect.Parameter.VAR_POSITIONAL, 
inspect.Parameter.VAR_KEYWORD)
+        for param in params
+    ):
+        return json.dumps({"type_schema": TypeSchema("Callable").to_json()})
+
+    arg_schemas: list[TypeSchema] = []
+    if not is_static:
+        arg_schemas.append(TypeSchema.from_annotation(cls))
+        params = params[1:]
+
+    for param in params:
+        if param.kind is inspect.Parameter.KEYWORD_ONLY:
+            return json.dumps({"type_schema": 
TypeSchema("Callable").to_json()})
+        annotation = hints.get(param.name, Any)
+        arg_schemas.append(TypeSchema.from_annotation(annotation))
+
+    ret_annotation = hints.get("return", Any)
+    ret_schema = TypeSchema.from_annotation(ret_annotation)
+    return json.dumps({"type_schema": TypeSchema("Callable", (ret_schema, 
*arg_schemas)).to_json()})
+
+
+def _collect_py_methods(
+    cls: type, globalns: dict[str, Any] | None = None
+) -> list[tuple[Any, ...]] | None:
     """Extract FFI-registered entries from a ``@py_class`` body.
 
     Two sources are collected:
@@ -349,10 +401,14 @@ def _collect_py_methods(cls: type) -> list[tuple[str, 
Any, bool]] | None:
     and Python protocol dunders cannot be ``@method``-decorated; those
     are reserved by the TypeAttrColumn and Python semantics respectively.
 
-    Returns the ``(name, value, is_static)`` list, or :data:`None` when
-    no entries were found.
+    Returns the ``(name, value, is_static, metadata_json)`` list, or
+    :data:`None` when no entries were found.
     """
-    methods: list[tuple[str, Any, bool]] = []
+    legacy_shape = globalns is None
+    if globalns is None:
+        globalns = vars(sys.modules[cls.__module__])
+
+    methods: list[tuple[Any, ...]] = []
     for name, value in cls.__dict__.items():
         marked = _is_method_marked(value)
         if name not in _FFI_RECOGNIZED_METHODS and not marked:
@@ -374,7 +430,13 @@ def _collect_py_methods(cls: type) -> list[tuple[str, Any, 
bool]] | None:
             )
         is_static = isinstance(value, staticmethod)
         func = value.__func__ if is_static else value
-        methods.append((name, func, is_static))
+        metadata_json = None
+        if marked:
+            metadata_json = _method_type_schema_json(cls, func, is_static, 
globalns)
+        if legacy_shape:
+            methods.append((name, func, is_static))
+        else:
+            methods.append((name, func, is_static, metadata_json))
     return methods if methods else None
 
 
@@ -447,7 +509,10 @@ def _register_fields_into_type(
             assert f.name is not None
             fields_map[f.name] = f
     own_fields = list(fields_map.values())
-    py_methods = _collect_py_methods(cls)
+    try:
+        py_methods = _collect_py_methods(cls, globalns)
+    except _DeferredAnnotation:
+        return False
 
     # Register fields and type-level structural eq/hash kind with the C layer.
     structure_kind = _STRUCTURE_KIND_MAP.get(params.get("structural_eq"))
@@ -603,6 +668,7 @@ _FFI_TYPE_ATTR_NAMES: frozenset[str] = frozenset(
         "__ffi_eq__",
         "__ffi_compare__",
         "__ffi_convert__",
+        "__ffi_convert_type_schema__",
         "__any_hash__",
         "__any_equal__",
         "__s_equal__",
@@ -627,6 +693,7 @@ _FFI_RECOGNIZED_METHODS: frozenset[str] = 
_FFI_TYPE_ATTR_NAMES
     eq_default=False,
     order_default=False,
     field_specifiers=(field, Field),
+    converter=_field_converter,
 )
 def py_class(  # noqa: PLR0913
     cls_or_type_key: type | str | None = None,
diff --git a/python/tvm_ffi/stub/python_generator/codegen.py b/python/tvm_ffi/stub/python_generator/codegen.py
index c3ffe76d..0f813464 100644
--- a/python/tvm_ffi/stub/python_generator/codegen.py
+++ b/python/tvm_ffi/stub/python_generator/codegen.py
@@ -30,6 +30,7 @@ from typing import Callable
 from .. import consts as C
 from ..file_utils import CodeBlock
 from ..utils import FuncInfo, InitConfig, ObjectInfo, Options
+from . import consts as PC
 from .utils import (
     ImportItem,
     render_func_signature,
@@ -103,6 +104,16 @@ def _type_suffix_and_record(
     return _run
 
 
+def _make_input_ty_map(ty_map: dict[str, str]) -> dict[str, str]:
+    """Derive input-side defaults without overriding explicit ty-map 
entries."""
+    input_ty_map = ty_map.copy()
+    for key, input_default in PC.TY_MAP_INPUT_DEFAULTS.items():
+        output_default = PC.TY_MAP_DEFAULTS.get(key)
+        if ty_map.get(key, output_default) == output_default:
+            input_ty_map[key] = input_default
+    return input_ty_map
+
+
 def generate_python_global_funcs(
     code: CodeBlock,
     global_funcs: list[FuncInfo],
@@ -136,11 +147,17 @@ def generate_python_global_funcs(
     )
     func_names = {f.schema.name.rsplit(".", 1)[-1] for f in global_funcs}
     fn_ty_map = _type_suffix_and_record(ty_map, imports, func_names=func_names)
+    input_fn_ty_map = _type_suffix_and_record(
+        _make_input_ty_map(ty_map), imports, func_names=func_names
+    )
     results: list[str] = [
         "# fmt: off",
         f'_FFI_INIT_FUNC("{prefix}", __name__)',
         "if TYPE_CHECKING:",
-        *[render_func_signature(func, fn_ty_map, opt.indent) for func in 
global_funcs],
+        *[
+            render_func_signature(func, fn_ty_map, opt.indent, 
input_ty_map=input_fn_ty_map)
+            for func in global_funcs
+        ],
         "# fmt: on",
     ]
     indent = " " * code.indent
@@ -166,12 +183,17 @@ def generate_python_object(
     info = obj_info
     method_names = {m.schema.name.rsplit(".", 1)[-1] for m in info.methods}
     fn_ty_map = _type_suffix_and_record(ty_map, imports, 
func_names=method_names)
-    init_lines = render_object_init(info, fn_ty_map, opt.indent)
-    ffi_init_lines = render_object_ffi_init(info, fn_ty_map, opt.indent)
+    input_fn_ty_map = _type_suffix_and_record(
+        _make_input_ty_map(ty_map), imports, func_names=method_names
+    )
+    init_lines = render_object_init(info, fn_ty_map, opt.indent, 
input_ty_map=input_fn_ty_map)
+    ffi_init_lines = render_object_ffi_init(
+        info, fn_ty_map, opt.indent, input_ty_map=input_fn_ty_map
+    )
     type_checking_lines = [
         *init_lines,
         *ffi_init_lines,
-        *render_object_methods(info, fn_ty_map, opt.indent),
+        *render_object_methods(info, fn_ty_map, opt.indent, 
input_ty_map=input_fn_ty_map),
     ]
     if type_checking_lines:
         imports.append(
diff --git a/python/tvm_ffi/stub/python_generator/consts.py b/python/tvm_ffi/stub/python_generator/consts.py
index a788bc68..5f931841 100644
--- a/python/tvm_ffi/stub/python_generator/consts.py
+++ b/python/tvm_ffi/stub/python_generator/consts.py
@@ -38,6 +38,13 @@ TY_MAP_DEFAULTS = {
     "Device": "ffi.Device",
 }
 
+TY_MAP_INPUT_DEFAULTS = {
+    "Array": "collections.abc.Sequence",
+    "List": "collections.abc.Sequence",
+    "Map": "collections.abc.Mapping",
+    "Dict": "collections.abc.Mapping",
+}
+
 # TODO(@junrushao): Make it configurable
 #: Module-prefix rewrites applied when constructing a Python ``import`` path.
 MOD_MAP = {
diff --git a/python/tvm_ffi/stub/python_generator/utils.py b/python/tvm_ffi/stub/python_generator/utils.py
index 02c0f7b4..37eff8f1 100644
--- a/python/tvm_ffi/stub/python_generator/utils.py
+++ b/python/tvm_ffi/stub/python_generator/utils.py
@@ -102,8 +102,11 @@ def render_func_signature(
     func: FuncInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> str:
     """Render a function signature string for ``func``."""
+    if input_ty_map is None:
+        input_ty_map = ty_map
     func_name = func.schema.name.rsplit(".", 1)[-1]
     buf = StringIO()
     buf.write(" " * indent)
@@ -121,12 +124,12 @@ def render_func_signature(
             buf.write("self, ")
         else:
             buf.write(f"_{i}: ")
-            buf.write(arg.repr(ty_map))
+            buf.write(arg.input_repr(input_ty_map))
             buf.write(", ")
     if arg_args:
         buf.write("/")
     buf.write(") -> ")
-    buf.write(arg_ret.repr(ty_map))
+    buf.write(arg_ret.output_repr(ty_map))
     buf.write(": ...")
     return buf.getvalue()
 
@@ -138,15 +141,18 @@ def render_object_fields(
 ) -> list[str]:
     """Render field definitions for ``info``."""
     indent_str = " " * indent
-    return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in 
info.fields]
+    return [f"{indent_str}{field.name}: {field.output_repr(ty_map)}" for field 
in info.fields]
 
 
 def render_object_methods(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> list[str]:
     """Render method definitions for ``info``."""
+    if input_ty_map is None:
+        input_ty_map = ty_map
     indent_str = " " * indent
     ret = []
     for method in info.methods:
@@ -154,11 +160,11 @@ def render_object_methods(
         if func_name == "__ffi_init__":
             # __ffi_init__ is installed as an instance method (self, *args, 
**kwargs) -> None
             # by _install_ffi_init_attr, regardless of the C++ static 
registration.
-            ret.append(_render_ffi_init_from_method(method, ty_map, indent))
+            ret.append(_render_ffi_init_from_method(method, ty_map, indent, 
input_ty_map))
             continue
         if not method.is_member:
             ret.append(f"{indent_str}@staticmethod")
-        ret.append(render_func_signature(method, ty_map, indent))
+        ret.append(render_func_signature(method, ty_map, indent, input_ty_map))
     return ret
 
 
@@ -166,8 +172,11 @@ def _render_ffi_init_from_method(
     method: FuncInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> str:
     """Render ``__ffi_init__`` TypeMethod as an instance method returning 
None."""
+    if input_ty_map is None:
+        input_ty_map = ty_map
     indent_str = " " * indent
     schema = method.schema
     # Subclass __ffi_init__ signatures legitimately differ from the parent
@@ -179,7 +188,7 @@ def _render_ffi_init_from_method(
     # schema.args[0] is return type, schema.args[1:] are param types.
     parts: list[str] = []
     for i, arg in enumerate(schema.args[1:]):
-        parts.append(f"_{i}: {arg.repr(ty_map)}")
+        parts.append(f"_{i}: {arg.input_repr(input_ty_map)}")
     if parts:
         params = ", ".join(parts)
         return f"{indent_str}def __ffi_init__(self, {params}, /) -> None: 
...{ignore}"
@@ -190,6 +199,7 @@ def render_object_ffi_init(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> list[str]:
     """Render a ``__ffi_init__`` stub when it's not already in TypeMethod.
 
@@ -203,25 +213,29 @@ def render_object_ffi_init(
     # If __ffi_init__ is already in methods (from TypeMethod), methods render 
it.
     if any(m.schema.name.rsplit(".", 1)[-1] == "__ffi_init__" for m in 
info.methods):
         return []
-    return _render_ffi_init_from_fields(info, ty_map, indent)
+    return _render_ffi_init_from_fields(info, ty_map, indent, input_ty_map)
 
 
 def render_object_init(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> list[str]:
     """Render an ``__init__`` stub from init-eligible field metadata."""
     if not info.has_init:
         return []
-    return _render_init_from_fields(info, ty_map, indent)
+    return _render_init_from_fields(info, ty_map, indent, input_ty_map)
 
 
 def _format_field_params(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> str:
     """Format init-eligible fields as a parameter string with defaults and 
kw_only."""
+    if input_ty_map is None:
+        input_ty_map = ty_map
     positional = [f for f in info.init_fields if not f.kw_only]
     kw_only = [f for f in info.init_fields if f.kw_only]
 
@@ -232,15 +246,15 @@ def _format_field_params(
 
     parts: list[str] = []
     for f in pos_required:
-        parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+        parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)}")
     for f in pos_default:
-        parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+        parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)} = ...")
     if kw_required or kw_default:
         parts.append("*")
         for f in kw_required:
-            parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+            parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)}")
         for f in kw_default:
-            parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+            parts.append(f"{f.name}: {f.schema.input_repr(input_ty_map)} = 
...")
 
     return ", ".join(parts)
 
@@ -249,10 +263,11 @@ def _render_init_from_fields(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> list[str]:
     """Render ``__init__`` from init-eligible field metadata (auto-generated 
init)."""
     indent_str = " " * indent
-    params = _format_field_params(info, ty_map)
+    params = _format_field_params(info, ty_map, input_ty_map)
     if params:
         return [f"{indent_str}def __init__(self, {params}) -> None: ..."]
     return [f"{indent_str}def __init__(self) -> None: ..."]
@@ -262,13 +277,14 @@ def _render_ffi_init_from_fields(
     info: ObjectInfo,
     ty_map: Callable[[str], str],
     indent: int,
+    input_ty_map: Callable[[str], str] | None = None,
 ) -> list[str]:
     """Render ``__ffi_init__`` stub from field metadata for auto-generated 
init."""
     indent_str = " " * indent
     # Subclass __ffi_init__ signatures legitimately differ from the parent
     # (different fields -> different constructor params), so suppress LSP.
     ignore = "  # ty: ignore[invalid-method-override]"
-    params = _format_field_params(info, ty_map)
+    params = _format_field_params(info, ty_map, input_ty_map)
     if params:
         return [f"{indent_str}def __ffi_init__(self, {params}) -> None: 
...{ignore}"]
     return [f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"]
diff --git a/tests/python/test_dataclass_c_class.py b/tests/python/test_dataclass_c_class.py
index 4d86073f..a1e3f23a 100644
--- a/tests/python/test_dataclass_c_class.py
+++ b/tests/python/test_dataclass_c_class.py
@@ -24,8 +24,8 @@ import warnings
 import pytest
 import tvm_ffi.testing
 from tvm_ffi.core import MISSING, TypeInfo
-from tvm_ffi.dataclasses import Field
-from tvm_ffi.dataclasses.c_class import _attach_field_objects
+from tvm_ffi.dataclasses import Field, field
+from tvm_ffi.dataclasses.c_class import _attach_field_objects, c_class
 from tvm_ffi.registry import _warn_missing_field_annotations
 from tvm_ffi.testing import (
     TestCompare,
@@ -42,6 +42,11 @@ from tvm_ffi.testing import (
 # ---------------------------------------------------------------------------
 
 
+def test_c_class_dataclass_transform_has_converter() -> None:
+    metadata = c_class.__dataclass_transform__  # ty: 
ignore[unresolved-attribute]
+    assert metadata["kwargs"]["converter"] is field().converter
+
+
 def test_c_class_custom_init() -> None:
     """c_class preserves user-defined __init__."""
     obj = _TestCxxClassBase(v_i64=10, v_i32=20)
diff --git a/tests/python/test_dataclass_py_class.py 
b/tests/python/test_dataclass_py_class.py
index 5a436144..49888aa5 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -53,6 +53,20 @@ def _get_type_info(cls: type) -> TypeInfo:
     return ret
 
 
+def test_py_class_dataclass_transform_has_converter() -> None:
+    metadata = py_class.__dataclass_transform__  # ty: 
ignore[unresolved-attribute]
+    assert metadata["kwargs"]["converter"] is field().converter
+
+
+def test_field_accepts_converter_metadata() -> None:
+    def converter(value: Any) -> Any:
+        return value
+
+    f = field(converter=converter)
+    assert isinstance(f, Field)
+    assert f.converter is converter
+
+
 # ---------------------------------------------------------------------------
 # Low-level helpers for _make_type-based tests
 # ---------------------------------------------------------------------------
diff --git a/tests/python/test_function.py b/tests/python/test_function.py
index 36291188..19c2517d 100644
--- a/tests/python/test_function.py
+++ b/tests/python/test_function.py
@@ -448,23 +448,25 @@ def test_convert_func_with_torch_tensor_cls() -> None:
     the outer caller's conversion path, so we verify shape survives the
     round-trip rather than isinstance on the return.
     """
+    assert torch is not None
+    torch_mod = torch
     calls = 0
 
     def callback(a: Any, b: Any, c: Any) -> Any:
         nonlocal calls
         calls += 1
-        assert isinstance(a, torch.Tensor)
-        assert isinstance(b, torch.Tensor)
-        assert isinstance(c, torch.Tensor)
+        assert isinstance(a, torch_mod.Tensor)
+        assert isinstance(b, torch_mod.Tensor)
+        assert isinstance(c, torch_mod.Tensor)
         assert list(a.shape) == [2]
         assert list(b.shape) == [3]
         assert list(c.shape) == [4]
         return b
 
-    f = tvm_ffi.convert_func(callback, tensor_cls=torch.Tensor)
-    a = torch.zeros(2)
-    b = torch.ones(3)
-    c = torch.full((4,), 2.0)
+    f = tvm_ffi.convert_func(callback, tensor_cls=torch_mod.Tensor)
+    a = torch_mod.zeros(2)
+    b = torch_mod.ones(3)
+    c = torch_mod.full((4,), 2.0)
     out = f(a, b, c)
     assert calls == 1
     assert tuple(out.shape) == (3,)
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
index 5c00a04d..a82c0760 100644
--- a/tests/python/test_stubgen.py
+++ b/tests/python/test_stubgen.py
@@ -16,11 +16,15 @@
 # under the License.
 from __future__ import annotations
 
+import itertools
+import typing
 from pathlib import Path
 
 import pytest
 import tvm_ffi.stub.cli as stub_cli
+from tvm_ffi import Object, method
 from tvm_ffi.core import TypeSchema
+from tvm_ffi.dataclasses import py_class
 from tvm_ffi.stub import consts as C
 from tvm_ffi.stub.cli import _stage_2, _stage_3
 from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
@@ -35,23 +39,32 @@ from tvm_ffi.stub.python_generator.codegen import (
     generate_python_init,
     generate_python_object,
     render_func_signature,
+    render_object_ffi_init,
     render_object_fields,
+    render_object_init,
     render_object_methods,
 )
 from tvm_ffi.stub.python_generator.utils import ImportItem
 from tvm_ffi.stub.utils import (
     FuncInfo,
     InitConfig,
+    InitFieldInfo,
     NamedTypeSchema,
     ObjectInfo,
     Options,
 )
 
+_counter = itertools.count()
+
 
 def _identity_ty_map(name: str) -> str:
     return name
 
 
+def _unique_type_key(base: str) -> str:
+    return f"testing.stubgen.{base}_{next(_counter)}"
+
+
 def _default_ty_map() -> dict[str, str]:
     return PC.TY_MAP_DEFAULTS.copy()
 
@@ -60,6 +73,10 @@ def _type_suffix(name: str) -> str:
     return PC.TY_MAP_DEFAULTS.get(name, name).rsplit(".", 1)[-1]
 
 
+def _input_type_suffix(name: str) -> str:
+    return PC.TY_MAP_INPUT_DEFAULTS.get(name, PC.TY_MAP_DEFAULTS.get(name, 
name)).rsplit(".", 1)[-1]
+
+
 def test_codeblock_from_begin_line_variants() -> None:
     cases = [
         (f"{C.PYTHON_SYNTAX.begin} global/demo", "global", ("demo", "")),
@@ -301,6 +318,126 @@ def test_objectinfo_gen_fields_container_types() -> None:
     ]
 
 
+def test_funcinfo_gen_uses_input_annotations_for_parameters() -> None:
+    info = FuncInfo(
+        schema=NamedTypeSchema(
+            "demo.echo_list",
+            TypeSchema(
+                "Callable",
+                (
+                    TypeSchema("List", (TypeSchema("int"),)),
+                    TypeSchema("List", (TypeSchema("int"),)),
+                ),
+            ),
+        ),
+        is_member=False,
+    )
+
+    assert (
+        render_func_signature(info, _type_suffix, indent=0, 
input_ty_map=_input_type_suffix)
+        == "def echo_list(_0: Sequence[int], /) -> MutableSequence[int]: ..."
+    )
+
+
+def test_generate_global_funcs_populates_input_defaults_for_partial_ty_map() 
-> None:
+    code = CodeBlock(
+        kind="global",
+        param=("demo", "mockpkg"),
+        lineno_start=1,
+        lineno_end=2,
+        lines=[f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg", 
C.PYTHON_SYNTAX.end],
+    )
+    funcs = [
+        FuncInfo(
+            schema=NamedTypeSchema(
+                "demo.echo_list",
+                TypeSchema(
+                    "Callable",
+                    (
+                        TypeSchema("List", (TypeSchema("int"),)),
+                        TypeSchema("List", (TypeSchema("int"),)),
+                    ),
+                ),
+            ),
+            is_member=False,
+        )
+    ]
+    imports: list[ImportItem] = []
+
+    generate_python_global_funcs(
+        code, funcs, {"List": "collections.abc.MutableSequence"}, imports, 
Options()
+    )
+
+    assert code.lines == [
+        f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
+        "# fmt: off",
+        '_FFI_INIT_FUNC("demo", __name__)',
+        "if TYPE_CHECKING:",
+        "    def echo_list(_0: Sequence[int], /) -> MutableSequence[int]: ...",
+        "# fmt: on",
+        C.PYTHON_SYNTAX.end,
+    ]
+
+
+def test_objectinfo_gen_init_uses_input_annotations() -> None:
+    info = ObjectInfo(
+        fields=[NamedTypeSchema("items", TypeSchema("List", 
(TypeSchema("int"),)))],
+        methods=[],
+        init_fields=[
+            InitFieldInfo(
+                name="items",
+                schema=NamedTypeSchema("items", TypeSchema("List", 
(TypeSchema("int"),))),
+                kw_only=False,
+                has_default=False,
+            )
+        ],
+        has_init=True,
+    )
+
+    assert render_object_fields(info, _type_suffix, indent=0) == ["items: 
MutableSequence[int]"]
+    assert render_object_init(info, _type_suffix, indent=0, 
input_ty_map=_input_type_suffix) == [
+        "def __init__(self, items: Sequence[int]) -> None: ..."
+    ]
+    assert render_object_ffi_init(
+        info, _type_suffix, indent=0, input_ty_map=_input_type_suffix
+    ) == [
+        "def __ffi_init__(self, items: Sequence[int]) -> None: ...  # ty: "
+        "ignore[invalid-method-override]"
+    ]
+
+
+def test_py_class_method_metadata_renders_stub_signature() -> None:
+    @py_class(_unique_type_key("MethodMetadata"))
+    class MethodMetadata(Object):
+        value: int
+
+        @method
+        def describe(self, values: typing.List[int], prefix: str) -> str:  # 
noqa: UP006
+            return f"{prefix}:{self.value}:{len(values)}"
+
+        @method
+        @staticmethod
+        def normalize(values: typing.List[int]) -> typing.List[int]:  # noqa: 
UP006
+            return values
+
+    info = ObjectInfo.from_type_info(MethodMetadata.__tvm_ffi_type_info__)  # 
ty: ignore[unresolved-attribute]
+    methods = {method.schema.name: method for method in info.methods}
+    describe_schema = methods["describe"].schema
+
+    assert describe_schema.origin == "Callable"
+    assert [arg.origin for arg in describe_schema.args] == [
+        "str",
+        MethodMetadata.__tvm_ffi_type_info__.type_key,  # ty: 
ignore[unresolved-attribute]
+        "List",
+        "str",
+    ]
+    assert render_object_methods(info, _type_suffix, indent=0, 
input_ty_map=_input_type_suffix) == [
+        "def describe(self, _1: Sequence[int], _2: str, /) -> str: ...",
+        "@staticmethod",
+        "def normalize(_0: Sequence[int], /) -> MutableSequence[int]: ...",
+    ]
+
+
 def test_generate_global_funcs_updates_block() -> None:
     code = CodeBlock(
         kind="global",
diff --git a/tests/python/test_type_converter.py 
b/tests/python/test_type_converter.py
index 82a0d7fb..a639e5d7 100644
--- a/tests/python/test_type_converter.py
+++ b/tests/python/test_type_converter.py
@@ -31,13 +31,14 @@ import pytest
 import tvm_ffi
 from tvm_ffi.core import (
     CAny,
+    Object,
     ObjectConvertible,
     TypeSchema,
     _lookup_type_attr,
     _object_type_key_to_index,
     _to_py_class_value,
 )
-from tvm_ffi.dataclasses import IntEnum, StrEnum, entry
+from tvm_ffi.dataclasses import IntEnum, StrEnum, entry, py_class
 from tvm_ffi.testing import (
     TestIntPair,
     TestObjectBase,
@@ -3443,6 +3444,54 @@ class TestSTLOriginParsing:
         assert s.origin == "Object"
 
 
+def _output_ty_map(name: str) -> str:
+    return {
+        "Array": "Sequence",
+        "List": "MutableSequence",
+        "Map": "Mapping",
+        "Dict": "MutableMapping",
+    }.get(name, name)
+
+
+def _input_ty_map(name: str) -> str:
+    return {
+        "Array": "Sequence",
+        "List": "Sequence",
+        "Map": "Mapping",
+        "Dict": "Mapping",
+    }.get(name, name)
+
+
+class TestTypeSchemaAnnotationRendering:
+    """Input annotations are widened without changing output annotations."""
+
+    def test_container_input_output_repr(self) -> None:
+        """List and Dict render differently for input and output positions."""
+        schema = S("Dict", S("str"), S("List", S("int")))
+
+        assert schema.output_repr(_output_ty_map) == "MutableMapping[str, 
MutableSequence[int]]"
+        assert schema.input_repr(_input_ty_map) == "Mapping[str, 
Sequence[int]]"
+        assert schema.repr(_output_ty_map) == 
schema.output_repr(_output_ty_map)
+
+    def test_object_convert_type_schema_attr_widens_input_only(self) -> None:
+        """__ffi_convert_type_schema__ affects input annotations only."""
+        type_key = _unique_type_key("ConvertTypeSchema")
+        convert_schema = (
+            
f'{{"type":"Union","args":[{{"type":"int"}},{{"type":"str"}},{{"type":"{type_key}"}}]}}'
+        )
+
+        @py_class(type_key)
+        class ExprLike(Object):
+            __ffi_convert_type_schema__ = convert_schema
+
+        def ty_map(name: str) -> str:
+            return "ExprLike" if name == type_key else name
+
+        schema = TypeSchema.from_annotation(ExprLike)
+        assert schema.output_repr(ty_map) == "ExprLike"
+        assert schema.input_repr(ty_map) == "int | str | ExprLike"
+
+
 # ---------------------------------------------------------------------------
 # Category 64: Zero-copy container conversion
 # ---------------------------------------------------------------------------


Reply via email to