This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 368af82  fix: `self` overlooked in schema of member functions (#91)
368af82 is described below

commit 368af824845424ea439b9f3d68bf4a710afb38b1
Author: Junru Shao <[email protected]>
AuthorDate: Tue Oct 7 04:47:55 2025 -0700

    fix: `self` overlooked in schema of member functions (#91)
    
    When generating the function schema for member functions of a class,
    the existing logic overlooked `this` (i.e. `self`) as the first argument
    and, as a result, generated an incorrect number of arguments.
    
    More specifically, in the snippet below,
    
    ```C++
    class MyCls {
      void MyFunc(int a);
    };
    
    refl::GlobalDef()
        .def_method("MyCls_MyFunc", &MyCls::MyFunc);
    ```
    
    the global method `MyCls_MyFunc` should have the signature `(MyCls,
    int) -> None`, but the implementation on mainline gives `(int) -> None`.
    
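    This PR fixes the issue by prepending the class pointer (`Class*`, or
    `const Class*` for const methods) to the argument list of the
    pointer-to-member `FunctionInfo` specializations, and updates the
    expected method schemas in the C++ and Python tests accordingly. It also
    relaxes `TypeSchema` validation to accept bare `list`/`dict` without type
    arguments, and adds a `testing.schema_id_func_typed` test for typed
    function schemas.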
---
 include/tvm/ffi/function_details.h  |  4 ++--
 python/tvm_ffi/cython/type_info.pxi |  4 ++--
 src/ffi/extra/testing.cc            |  3 +++
 tests/cpp/test_metadata.cc          | 14 +++++++++-----
 tests/python/test_metadata.py       | 12 ++++++++----
 5 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/include/tvm/ffi/function_details.h 
b/include/tvm/ffi/function_details.h
index ae9de33..01c0bf1 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -115,9 +115,9 @@ template <typename R, typename... Args>
 struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {};
 // Support pointer-to-member functions used in reflection (e.g. &Class::method)
 template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Class*, 
Args...> {};
 template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, Args...> 
{};
+struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, const 
Class*, Args...> {};
 
 /*! \brief Using static function to output typed function signature */
 typedef std::string (*FGetFuncSignature)();
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index 337e3db..4abfa4a 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -99,9 +99,9 @@ class TypeSchema:
         elif origin == "Optional":
             assert len(args) == 1, "Optional must have exactly one argument"
         elif origin == "list":
-            assert len(args) == 1, "list must have exactly one argument"
+            assert len(args) in (0, 1), "list must have 0 or 1 argument"
         elif origin == "dict":
-            assert len(args) == 2, "dict must have exactly two arguments"
+            assert len(args) in (0, 2), "dict must have 0 or 2 arguments"
         elif origin == "tuple":
             pass  # tuple can have arbitrary number of arguments
 
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 555a14b..7263b43 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -373,6 +373,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def("testing.schema_id_string", [](String s) { return s; })
       .def("testing.schema_id_bytes", [](Bytes b) { return b; })
       .def("testing.schema_id_func", [](Function f) -> Function { return f; })
+      .def("testing.schema_id_func_typed",
+           [](TypedFunction<void(int64_t, float, Function)> f)
+               -> TypedFunction<void(int64_t, float, Function)> { return f; })
       .def("testing.schema_id_any", [](Any a) { return a; })
       .def("testing.schema_id_object", [](ObjectRef o) { return o; })
       .def("testing.schema_id_dltensor", [](DLTensor* t) { return t; })
diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc
index e0fa054..b5fedb8 100644
--- a/tests/cpp/test_metadata.cc
+++ b/tests/cpp/test_metadata.cc
@@ -70,6 +70,9 @@ TEST(Schema, GlobalFuncTypeSchema) {
             
R"({"type":"ffi.Function","args":[{"type":"ffi.Bytes"},{"type":"ffi.Bytes"}]})");
   EXPECT_EQ(fetch("testing.schema_id_func"),
             
R"({"type":"ffi.Function","args":[{"type":"ffi.Function"},{"type":"ffi.Function"}]})");
+  EXPECT_EQ(
+      fetch("testing.schema_id_func_typed"),
+      
R"({"type":"ffi.Function","args":[{"type":"ffi.Function","args":[{"type":"None"},{"type":"int"},{"type":"float"},{"type":"ffi.Function"}]},{"type":"ffi.Function","args":[{"type":"None"},{"type":"int"},{"type":"float"},{"type":"ffi.Function"}]}]})");
 
   EXPECT_EQ(fetch("testing.schema_id_any"),
             
R"({"type":"ffi.Function","args":[{"type":"Any"},{"type":"Any"}]})");
@@ -173,17 +176,18 @@ TEST(Schema, MethodTypeSchemas) {
   };
 
   // Instance methods
-  EXPECT_EQ(method_schema("add_int"),
-            
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"int"}]})");
+  EXPECT_EQ(
+      method_schema("add_int"),
+      
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"testing.SchemaAllTypes"},{"type":"int"}]})");
   EXPECT_EQ(
       method_schema("append_int"),
-      
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"int"}]})");
+      
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"testing.SchemaAllTypes"},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"int"}]})");
   EXPECT_EQ(
       method_schema("maybe_concat"),
-      
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
+      
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"testing.SchemaAllTypes"},{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
   EXPECT_EQ(
       method_schema("merge_map"),
-      
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
+      
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"testing.SchemaAllTypes"},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
 
   // Static method make_with: return type is the object type itself.
   // Build expected JSON as ffi.Function with return type = type_key and args 
= (int, float, str)
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
index 527ffba..9f69bcc 100644
--- a/tests/python/test_metadata.py
+++ b/tests/python/test_metadata.py
@@ -33,6 +33,10 @@ from tvm_ffi.testing import _SchemaAllTypes
         ("testing.schema_id_string", "Callable[[str], str]"),
         ("testing.schema_id_bytes", "Callable[[bytes], bytes]"),
         ("testing.schema_id_func", "Callable[[Callable[..., Any]], 
Callable[..., Any]]"),
+        (
+            "testing.schema_id_func_typed",
+            "Callable[[Callable[[int, float, Callable[..., Any]], None]], 
Callable[[int, float, Callable[..., Any]], None]]",
+        ),
         ("testing.schema_id_any", "Callable[[Any], Any]"),
         ("testing.schema_id_object", "Callable[[Object], Object]"),
         ("testing.schema_id_dltensor", "Callable[[Tensor], Tensor]"),
@@ -99,12 +103,12 @@ def test_schema_field(field_name: str, expected: str) -> 
None:
 @pytest.mark.parametrize(
     "method_name,expected",
     [
-        ("add_int", "Callable[[int], int]"),
-        ("append_int", "Callable[[list[int], int], list[int]]"),
-        ("maybe_concat", "Callable[[str | None, str | None], str | None]"),
+        ("add_int", "Callable[[testing.SchemaAllTypes, int], int]"),
+        ("append_int", "Callable[[testing.SchemaAllTypes, list[int], int], 
list[int]]"),
+        ("maybe_concat", "Callable[[testing.SchemaAllTypes, str | None, str | 
None], str | None]"),
         (
             "merge_map",
-            "Callable[[dict[str, list[int]], dict[str, list[int]]], dict[str, 
list[int]]]",
+            "Callable[[testing.SchemaAllTypes, dict[str, list[int]], dict[str, 
list[int]]], dict[str, list[int]]]",
         ),
         ("make_with", "Callable[[int, float, str], testing.SchemaAllTypes]"),
     ],

Reply via email to