This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 368af82 fix: `self` overlooked in schema of member functions (#91)
368af82 is described below
commit 368af824845424ea439b9f3d68bf4a710afb38b1
Author: Junru Shao <[email protected]>
AuthorDate: Tue Oct 7 04:47:55 2025 -0700
fix: `self` overlooked in schema of member functions (#91)
When generating the function schema for member functions of a class,
the existing logic overlooked `this` as the first argument, which, as a
result, generated an improper number of arguments.
More specifically, in the snippet below,
```C++
class MyCls {
void MyFunc(int a);
};
refl::GlobalDef()
.def_method("MyCls_MyFunc", &MyCls::MyFunc);
```
the global method `MyCls_MyFunc` is supposed to have signature `(MyCls,
int) -> None`, but the implementation on mainline gives `(int) -> None`.
This PR fixes this issue.
---
include/tvm/ffi/function_details.h | 4 ++--
python/tvm_ffi/cython/type_info.pxi | 4 ++--
src/ffi/extra/testing.cc | 3 +++
tests/cpp/test_metadata.cc | 14 +++++++++-----
tests/python/test_metadata.py | 12 ++++++++----
5 files changed, 24 insertions(+), 13 deletions(-)
diff --git a/include/tvm/ffi/function_details.h
b/include/tvm/ffi/function_details.h
index ae9de33..01c0bf1 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -115,9 +115,9 @@ template <typename R, typename... Args>
struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {};
// Support pointer-to-member functions used in reflection (e.g. &Class::method)
template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Class*,
Args...> {};
template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, Args...>
{};
+struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, const
Class*, Args...> {};
/*! \brief Using static function to output typed function signature */
typedef std::string (*FGetFuncSignature)();
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 337e3db..4abfa4a 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -99,9 +99,9 @@ class TypeSchema:
elif origin == "Optional":
assert len(args) == 1, "Optional must have exactly one argument"
elif origin == "list":
- assert len(args) == 1, "list must have exactly one argument"
+ assert len(args) in (0, 1), "list must have 0 or 1 argument"
elif origin == "dict":
- assert len(args) == 2, "dict must have exactly two arguments"
+ assert len(args) in (0, 2), "dict must have 0 or 2 arguments"
elif origin == "tuple":
pass # tuple can have arbitrary number of arguments
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 555a14b..7263b43 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -373,6 +373,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("testing.schema_id_string", [](String s) { return s; })
.def("testing.schema_id_bytes", [](Bytes b) { return b; })
.def("testing.schema_id_func", [](Function f) -> Function { return f; })
+ .def("testing.schema_id_func_typed",
+ [](TypedFunction<void(int64_t, float, Function)> f)
+ -> TypedFunction<void(int64_t, float, Function)> { return f; })
.def("testing.schema_id_any", [](Any a) { return a; })
.def("testing.schema_id_object", [](ObjectRef o) { return o; })
.def("testing.schema_id_dltensor", [](DLTensor* t) { return t; })
diff --git a/tests/cpp/test_metadata.cc b/tests/cpp/test_metadata.cc
index e0fa054..b5fedb8 100644
--- a/tests/cpp/test_metadata.cc
+++ b/tests/cpp/test_metadata.cc
@@ -70,6 +70,9 @@ TEST(Schema, GlobalFuncTypeSchema) {
R"({"type":"ffi.Function","args":[{"type":"ffi.Bytes"},{"type":"ffi.Bytes"}]})");
EXPECT_EQ(fetch("testing.schema_id_func"),
R"({"type":"ffi.Function","args":[{"type":"ffi.Function"},{"type":"ffi.Function"}]})");
+ EXPECT_EQ(
+ fetch("testing.schema_id_func_typed"),
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Function","args":[{"type":"None"},{"type":"int"},{"type":"float"},{"type":"ffi.Function"}]},{"type":"ffi.Function","args":[{"type":"None"},{"type":"int"},{"type":"float"},{"type":"ffi.Function"}]}]})");
EXPECT_EQ(fetch("testing.schema_id_any"),
R"({"type":"ffi.Function","args":[{"type":"Any"},{"type":"Any"}]})");
@@ -173,17 +176,18 @@ TEST(Schema, MethodTypeSchemas) {
};
// Instance methods
- EXPECT_EQ(method_schema("add_int"),
-
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"int"}]})");
+ EXPECT_EQ(
+ method_schema("add_int"),
+
R"({"type":"ffi.Function","args":[{"type":"int"},{"type":"testing.SchemaAllTypes"},{"type":"int"}]})");
EXPECT_EQ(
method_schema("append_int"),
-
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"int"}]})");
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"testing.SchemaAllTypes"},{"type":"ffi.Array","args":[{"type":"int"}]},{"type":"int"}]})");
EXPECT_EQ(
method_schema("maybe_concat"),
-
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
+
R"({"type":"ffi.Function","args":[{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"testing.SchemaAllTypes"},{"type":"Optional","args":[{"type":"ffi.String"}]},{"type":"Optional","args":[{"type":"ffi.String"}]}]})");
EXPECT_EQ(
method_schema("merge_map"),
-
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
+
R"({"type":"ffi.Function","args":[{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"testing.SchemaAllTypes"},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]},{"type":"ffi.Map","args":[{"type":"ffi.String"},{"type":"ffi.Array","args":[{"type":"int"}]}]}]})");
// Static method make_with: return type is the object type itself.
// Build expected JSON as ffi.Function with return type = type_key and args
= (int, float, str)
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
index 527ffba..9f69bcc 100644
--- a/tests/python/test_metadata.py
+++ b/tests/python/test_metadata.py
@@ -33,6 +33,10 @@ from tvm_ffi.testing import _SchemaAllTypes
("testing.schema_id_string", "Callable[[str], str]"),
("testing.schema_id_bytes", "Callable[[bytes], bytes]"),
("testing.schema_id_func", "Callable[[Callable[..., Any]],
Callable[..., Any]]"),
+ (
+ "testing.schema_id_func_typed",
+ "Callable[[Callable[[int, float, Callable[..., Any]], None]],
Callable[[int, float, Callable[..., Any]], None]]",
+ ),
("testing.schema_id_any", "Callable[[Any], Any]"),
("testing.schema_id_object", "Callable[[Object], Object]"),
("testing.schema_id_dltensor", "Callable[[Tensor], Tensor]"),
@@ -99,12 +103,12 @@ def test_schema_field(field_name: str, expected: str) ->
None:
@pytest.mark.parametrize(
"method_name,expected",
[
- ("add_int", "Callable[[int], int]"),
- ("append_int", "Callable[[list[int], int], list[int]]"),
- ("maybe_concat", "Callable[[str | None, str | None], str | None]"),
+ ("add_int", "Callable[[testing.SchemaAllTypes, int], int]"),
+ ("append_int", "Callable[[testing.SchemaAllTypes, list[int], int],
list[int]]"),
+ ("maybe_concat", "Callable[[testing.SchemaAllTypes, str | None, str |
None], str | None]"),
(
"merge_map",
- "Callable[[dict[str, list[int]], dict[str, list[int]]], dict[str,
list[int]]]",
+ "Callable[[testing.SchemaAllTypes, dict[str, list[int]], dict[str,
list[int]]], dict[str, list[int]]]",
),
("make_with", "Callable[[int, float, str], testing.SchemaAllTypes]"),
],