This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6f1b459 [FEAT][REFLECTION] Add tvm::ffi::reflection::overload_cast to
pick overloaded function (#582)
6f1b459 is described below
commit 6f1b459532bd54c942cda3d932ea4f3f629c6df0
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun May 10 17:36:25 2026 -0400
[FEAT][REFLECTION] Add tvm::ffi::reflection::overload_cast to pick
overloaded function (#582)
## Summary
Introduce `tvm::ffi::reflection::overload_cast<Args...>` — a `constexpr`
helper for picking a specific overload of an overloaded callable by
spelling out a parameter-type prefix that uniquely identifies it.
Trailing parameter types (if any) are deduced from the picked overload's
signature. The result is a typed function pointer (member or free)
usable wherever a typed fn ptr is required, including as a non-type
template argument.
If the prefix matches multiple overloads and none of them has exactly the
spelled parameter list, the call does not compile; the caller spells more
parameter types until exactly one overload matches.
Const-qualified members are picked via the
`tvm::ffi::reflection::const_` tag.
---
include/tvm/ffi/reflection/registry.h | 117 ++++++++++++++++++++++++++++++++++
tests/cpp/test_reflection.cc | 108 +++++++++++++++++++++++++++++++
2 files changed, 225 insertions(+)
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 3e715fe..543321c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -1034,6 +1034,123 @@ inline void EnsureTypeAttrColumn(std::string_view name)
{
reinterpret_cast<const
TVMFFIAny*>(&any_view)));
}
+/// \cond Doxygen_Suppress
+namespace details {
+
+/*!
+ * \brief Implementation struct for overload_cast.
+ *
+ * Provides operator() overloads for each callable kind (free function,
+ * non-const member, const member), in two flavors: full match where Args...
+ * is the entire parameter list, and prefix match where Args... is a leading
+ * prefix and the trailing parameters Rest... are deduced from the picked
+ * overload's signature.
+ *
+ * The member-function overloads are told apart by their second parameter:
+ * the non-const forms take a defaulted std::false_type, while the const forms
+ * require an explicit std::true_type argument (the `const_` tag).
+ *
+ * \note Only unqualified and `const`-qualified member functions can be
+ *       matched; ref-qualified (`&` / `&&`) and `volatile`-qualified member
+ *       functions are not covered by any overload below.
+ */
+template <typename... Args>
+struct OverloadCastImpl {
+  // Why there is a full-match triplet in addition to the prefix triplet:
+  //
+  // When the argument is an overload set, each operator() below deduces its
+  // parameter by trying every member of the set; if more than one member
+  // fits, the parameter becomes non-deduced and that operator() drops out.
+  // For overloads Run(int, int) / Run(int, int, int), the prefix form of
+  // overload_cast<int, int> therefore fails (both fit), while the full-match
+  // form succeeds (only Run(int, int) fits exactly). When both forms succeed
+  // for the same member (Rest... deduced empty), partial ordering prefers
+  // the more specialized full-match form, so the call is not ambiguous.
+
+  // Full match: the picked overload's parameter list is exactly Args...
+  template <typename Ret>
+  constexpr auto operator()(Ret (*fn)(Args...)) const noexcept {
+    return fn;
+  }
+  template <typename Ret, typename Cls>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args...), std::false_type = {}) const noexcept {
+    return pmf;
+  }
+  template <typename Ret, typename Cls>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args...) const, std::true_type) const noexcept {
+    return pmf;
+  }
+
+  // Prefix match: Args... is a leading prefix; Rest... is deduced from the
+  // picked overload's remaining parameters.
+  template <typename Ret, typename... Rest>
+  constexpr auto operator()(Ret (*fn)(Args..., Rest...)) const noexcept {
+    return fn;
+  }
+  template <typename Ret, typename Cls, typename... Rest>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...),
+                            std::false_type = {}) const noexcept {
+    return pmf;
+  }
+  template <typename Ret, typename Cls, typename... Rest>
+  constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...) const,
+                            std::true_type) const noexcept {
+    return pmf;
+  }
+};
+
+}  // namespace details
+/// \endcond
+
+/*!
+ * \brief Cast an overloaded callable to a specific overload, picked by
+ * spelling out a parameter-type prefix that uniquely identifies it.
+ *
+ * `Args...` is matched against the leading parameters of each candidate
+ * overload; the trailing parameter types (if any) are deduced from the
+ * picked overload's signature. The returned value is a constexpr function
+ * pointer (member or free) and can be used wherever a typed function
+ * pointer is required, including as a non-type template argument.
+ *
+ * An overload whose parameter list is exactly `Args...` is always selectable,
+ * even when other overloads start with the same parameters. Otherwise, if
+ * the prefix matches more than one overload, nothing is selected and the
+ * call does not compile; the caller must spell out more parameter types
+ * until exactly one overload matches.
+ *
+ * \note When picking a const-qualified member function,
+ *       `tvm::ffi::reflection::const_` must be passed as the second argument
+ *       even when it is the only overload of its name. Without the tag the
+ *       call does not compile.
+ *
+ * \note Ref-qualified (`&` / `&&`) and `volatile`-qualified member functions
+ *       are not supported.
+ *
+ * \note This helper can be more permissive than some `overload_cast` variants
+ *       in existing packages that require the full parameter list to be
+ *       spelled out: here a parameter-type prefix is accepted and the
+ *       trailing types are deduced from the picked overload.
+ *
+ * \code{.cpp}
+ * class Pet {
+ *  public:
+ *   void Set(int);
+ *   void Set(const std::string&);
+ *   int Feed(const Cat*, int amount);
+ *   int Feed(const Dog*, int amount);
+ *   int Run(int, int);
+ *   int Run(int, int, int);
+ *   int Get(int);
+ *   int Get(int) const;
+ * };
+ *
+ * namespace refl = tvm::ffi::reflection;
+ *
+ * // Spell only the disambiguating first arg; the trailing `int amount`
+ * // is deduced from the picked overload's signature.
+ * auto p_feed_cat = refl::overload_cast<const Cat*>(&Pet::Feed);
+ * // decltype(p_feed_cat) == int (Pet::*)(const Cat*, int)
+ *
+ * // Overloads that differ in their first parameter are picked by it.
+ * auto p_set_int = refl::overload_cast<int>(&Pet::Set);
+ *
+ * // Overloads that share a prefix: <int> alone matches both Run overloads,
+ * // so spell the full parameter list of the one you want.
+ * auto p_run2 = refl::overload_cast<int, int>(&Pet::Run);
+ *
+ * // Const-qualified member — opt in via the const_ tag:
+ * auto p_get_const = refl::overload_cast<int>(&Pet::Get, refl::const_);
+ *
+ * // Use directly as a non-type template argument:
+ * template <auto F> struct UseAsTemplateArg { ... };
+ * using U = UseAsTemplateArg<refl::overload_cast<const Cat*>(&Pet::Feed)>;
+ * \endcode
+ */
+template <typename... Args>
+inline constexpr details::OverloadCastImpl<Args...> overload_cast = {};
+
+/// \cond Doxygen_Suppress
+// Tag passed as the second argument of overload_cast to select the
+// const-qualified member overload. Doc emission is suppressed on purpose:
+// the trailing underscore in `const_` is parsed as RST hyperlink-reference
+// syntax on the exhale-generated per-variable page. The name is still
+// mentioned (as an inline literal) in the overload_cast docstring.
+inline constexpr std::true_type const_{};
+/// \endcond
+
} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index f9d567f..0593eca 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -625,4 +625,112 @@ TEST(Reflection, AutoInitSimpleTooManyArgs) {
EXPECT_THROW(auto_init(int64_t{1}, int64_t{2}, int64_t{3}), std::exception);
}
+// ---------------------------------------------------------------------------
+// overload_cast — pick an overload by prefix-matching its parameter types.
+// ---------------------------------------------------------------------------
+
+namespace overload_cast_test {
+
+// Empty tag types whose only job is to make the Feed overloads distinct.
+struct Cat {};
+struct Dog {};
+
+// Feed: overloads differ in their first parameter and share a trailing
+// `int amount`, so spelling the first parameter alone is enough to pick one.
+// Get: a non-const / const pair with identical parameters; the const one is
+// reached through the const_ tag.
+struct Pet {
+  int Feed(const Cat*, int amount) { return amount + 100; }
+  int Feed(const Dog*, int amount) { return amount + 200; }
+  int Get(int x) { return x + 4000; }
+  int Get(int x) const { return x + 5000; }
+};
+
+// Run: every overload starts with `int` and two start with `int, int`, so
+// the caller has to spell more parameters to single one out.
+struct Mix {
+  int Run(int, int) { return 1000; }
+  int Run(int, double) { return 2000; }
+  int Run(int, int, int) { return 3000; }
+};
+
+// Free-function counterpart of Pet::Feed.
+int FreeFeed(const Cat*, int x) { return x + 6000; }
+int FreeFeed(const Dog*, int x) { return x + 7000; }
+
+// Invokes the member-function pointer given as the non-type template
+// argument `Method` on `self`, forwarding the remaining arguments.
+template <auto Method>
+struct CallVia {
+  template <typename Obj, typename... Ts>
+  static auto Run(Obj&& self, Ts&&... args) {
+    return (std::forward<Obj>(self).*Method)(std::forward<Ts>(args)...);
+  }
+};
+
+}  // namespace overload_cast_test
+
+TEST(OverloadCast, PrefixMatch) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet pet;
+  Cat cat;
+  Dog dog;
+
+  // (a) Member overloads whose first parameter is unique: naming only that
+  //     parameter selects one, and the trailing `int amount` comes from the
+  //     selected signature.
+  auto feed_cat = refl::overload_cast<const Cat*>(&Pet::Feed);
+  auto feed_dog = refl::overload_cast<const Dog*>(&Pet::Feed);
+  static_assert(std::is_same_v<decltype(feed_cat), int (Pet::*)(const Cat*, int)>,
+                "prefix match must deduce trailing arg types");
+  EXPECT_EQ((pet.*feed_cat)(&cat, 7), 107);
+  EXPECT_EQ((pet.*feed_dog)(&dog, 12), 212);
+
+  // (b) The same selection applied to a free-function overload set.
+  auto free_cat = refl::overload_cast<const Cat*>(&FreeFeed);
+  auto free_dog = refl::overload_cast<const Dog*>(&FreeFeed);
+  EXPECT_EQ(free_cat(&cat, 7), 6007);
+  EXPECT_EQ(free_dog(&dog, 7), 7007);
+}
+
+TEST(OverloadCast, AmbiguousPrefixRequiresMoreSpelling) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Mix mix;
+
+  // Mix::Run overloads: Run(int, int), Run(int, double), Run(int, int, int).
+  // <int> alone would match all three, so each pick below names enough
+  // parameter types to identify a single overload (an exact full match is
+  // preferred over a longer overload that shares the same prefix).
+  auto run_ii = refl::overload_cast<int, int>(&Mix::Run);
+  auto run_id = refl::overload_cast<int, double>(&Mix::Run);
+  auto run_iii = refl::overload_cast<int, int, int>(&Mix::Run);
+  EXPECT_EQ((mix.*run_ii)(0, 0), 1000);
+  EXPECT_EQ((mix.*run_id)(0, 0.0), 2000);
+  EXPECT_EQ((mix.*run_iii)(0, 0, 0), 3000);
+}
+
+TEST(OverloadCast, ConstQualifiedMember) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet pet;
+  const Pet& const_pet = pet;
+
+  // Without a tag, overload_cast only considers non-const members.
+  auto get_mutable = refl::overload_cast<int>(&Pet::Get);
+  EXPECT_EQ((pet.*get_mutable)(7), 4007);
+
+  // The const overload is reachable only through the const_ tag. This also
+  // holds for a const member that has no non-const sibling: without the tag
+  // no operator() of overload_cast accepts a const member-function pointer.
+  auto get_const = refl::overload_cast<int>(&Pet::Get, refl::const_);
+  EXPECT_EQ((const_pet.*get_const)(7), 5007);
+}
+
+TEST(OverloadCast, NonTypeTemplateArgument) {
+  using namespace overload_cast_test;
+  namespace refl = tvm::ffi::reflection;
+  Pet p;
+  Mix m;
+  Cat cat;
+
+  // overload_cast is a constexpr variable with constexpr call operators, so
+  // its result can be passed directly as a non-type template argument.
+
+  // Prefix match (trailing `int amount` deduced) as a template argument.
+  EXPECT_EQ((CallVia<refl::overload_cast<const Cat*>(&Pet::Feed)>::Run(p, &cat, 7)), 107);
+
+  // Fully spelled 3-arg overload, disambiguated from Run(int, int), as a
+  // template argument.
+  EXPECT_EQ((CallVia<refl::overload_cast<int, int, int>(&Mix::Run)>::Run(m, 0, 0, 0)), 3000);
+}
+
} // namespace