This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch overload-cast-helper in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
commit 056b62b0375a7661fde1bfddae27be50be6bdc2b Author: tqchen <[email protected]> AuthorDate: Sun May 10 15:52:36 2026 +0000 [FEAT] Add tvm::ffi::reflection::overload_cast helper Introduce a constexpr helper for picking a specific overload of an overloaded callable by spelling out a parameter-type prefix that uniquely identifies it. Trailing parameter types (if any) are deduced from the picked overload's signature. The result is a typed function pointer (member or free) usable wherever a typed fn ptr is required, including as a non-type template argument. If the prefix matches multiple overloads, the call is ambiguous and the caller must spell more parameters until exactly one overload matches. Const-qualified members are picked via the tvm::ffi::reflection::const_ tag. Place in tvm::ffi::reflection alongside AttachFieldFlag and other registration-time helpers. Tests cover prefix matching, shared-prefix disambiguation, const qualification, and use as a non-type template argument. --- include/tvm/ffi/reflection/registry.h | 105 +++++++++++++++++++++++++++++++++ tests/cpp/test_reflection.cc | 108 ++++++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h index 3e715fe..b7498c6 100644 --- a/include/tvm/ffi/reflection/registry.h +++ b/include/tvm/ffi/reflection/registry.h @@ -1034,6 +1034,111 @@ inline void EnsureTypeAttrColumn(std::string_view name) { reinterpret_cast<const TVMFFIAny*>(&any_view))); } +namespace details { + +/*! + * \brief Implementation struct for overload_cast. + * + * Provides operator() overloads for each callable kind (free function, + * non-const member, const member), in two flavors: full match where Args... + * is the entire parameter list, and prefix match where Args... is a leading + * prefix and the trailing parameters Rest... are deduced from the picked + * overload's signature. Const-member candidates require a std::true_type tag. + */ +template <typename...
Args> +struct OverloadCastImpl { + // The first triplet handles the case where Args... is the complete + // parameter list of the picked overload. The second triplet handles + // the prefix-match case where the picked overload has additional + // trailing parameters Rest... beyond Args...; when both deduce the same + // overload, partial ordering prefers the first triplet. When several + // overloads match a prefix, that deduction fails (non-deduced context), + // so spelling the full parameter list selects the exact-arity overload. + + template <typename Ret> + constexpr auto operator()(Ret (*fn)(Args...)) const noexcept { + return fn; + } + template <typename Ret, typename Cls> + constexpr auto operator()(Ret (Cls::*pmf)(Args...), std::false_type = {}) const noexcept { + return pmf; + } + template <typename Ret, typename Cls> + constexpr auto operator()(Ret (Cls::*pmf)(Args...) const, std::true_type) const noexcept { + return pmf; + } + + template <typename Ret, typename... Rest> + constexpr auto operator()(Ret (*fn)(Args..., Rest...)) const noexcept { + return fn; + } + template <typename Ret, typename Cls, typename... Rest> + constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...), + std::false_type = {}) const noexcept { + return pmf; + } + template <typename Ret, typename Cls, typename... Rest> + constexpr auto operator()(Ret (Cls::*pmf)(Args..., Rest...) const, + std::true_type) const noexcept { + return pmf; + } +}; + +} // namespace details + +/*! + * \brief Cast an overloaded callable to a specific overload, picked by + * spelling out a parameter-type prefix that uniquely identifies it. + * + * `Args...` is matched against the leading parameters of each candidate + * overload; the trailing parameter types (if any) are deduced from the + * picked overload's signature. The returned value is a constexpr function + * pointer (member or free) and can be used wherever a typed function + * pointer is required, including as a non-type template argument. + * + * If the prefix matches multiple overloads (e.g.
two overloads share the + * same leading parameters) and none has exactly that parameter list, the + * call does not compile; spell more parameters until exactly one matches. + * + * \note When picking a const-qualified member function, `refl::const_` must + * be passed as the second argument even when it is the only overload + * of its name. Without the tag the call does not compile. + * + * \code{.cpp} + * class Pet { + * public: + * void Set(int); + * void Set(const std::string&); + * int Feed(const Cat*, int amount); + * int Feed(const Dog*, int amount); + * int Get(int); + * int Get(int) const; + * }; + * + * namespace refl = tvm::ffi::reflection; + * + * // Spell only the disambiguating first arg; the trailing `int amount` + * // is deduced from the picked overload's signature. + * auto p_feed_cat = refl::overload_cast<const Cat*>(&Pet::Feed); + * // decltype(p_feed_cat) == int (Pet::*)(const Cat*, int) + * + * // Full parameter list: picks Set(int) over Set(const std::string&). + * auto p_set_int = refl::overload_cast<int>(&Pet::Set); + * + * // Const-qualified member — opt in via the const_ tag: + * auto p_get_const = refl::overload_cast<int>(&Pet::Get, refl::const_); + * + * // Use directly as a non-type template argument: + * template <auto F> struct UseAsTemplateArg { ... }; + * using U = UseAsTemplateArg<refl::overload_cast<const Cat*>(&Pet::Feed)>; + * \endcode + */ +template <typename... Args> +inline constexpr details::OverloadCastImpl<Args...> overload_cast = {}; + +/*! \brief Tag used to select the const-qualified overload in overload_cast.
*/ +inline constexpr auto const_ = std::true_type{}; + } // namespace reflection } // namespace ffi } // namespace tvm diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc index f9d567f..0593eca 100644 --- a/tests/cpp/test_reflection.cc +++ b/tests/cpp/test_reflection.cc @@ -625,4 +625,112 @@ TEST(Reflection, AutoInitSimpleTooManyArgs) { EXPECT_THROW(auto_init(int64_t{1}, int64_t{2}, int64_t{3}), std::exception); } +// --------------------------------------------------------------------------- +// overload_cast — pick an overload by prefix-matching its parameter types. +// --------------------------------------------------------------------------- + +namespace overload_cast_test { + +struct Cat {}; +struct Dog {}; + +// Pet: each Feed overload has a unique first arg plus a trailing context +// (`int amount`) that the caller doesn't have to spell out. Get is const +// vs non-const overloaded with identical params (selected via const_ tag). +struct Pet { + int Feed(const Cat*, int amount) { return 100 + amount; } + int Feed(const Dog*, int amount) { return 200 + amount; } + int Get(int x) { return 4000 + x; } + int Get(int x) const { return 5000 + x; } +}; + +// Mix: overloads share a leading prefix — spelling the full parameter +// list selects the exact-arity overload over the longer variant. +struct Mix { + int Run(int, int) { return 1000; } + int Run(int, double) { return 2000; } + int Run(int, int, int) { return 3000; } +}; + +int FreeFeed(const Cat*, int x) { return 6000 + x; } +int FreeFeed(const Dog*, int x) { return 7000 + x; } + +template <auto Method> +struct CallVia { + template <typename Self, typename... Args> + static auto Run(Self&& self, Args&&...
args) { + return (std::forward<Self>(self).*Method)(std::forward<Args>(args)...); + } +}; + +} // namespace overload_cast_test + +TEST(OverloadCast, PrefixMatch) { + using namespace overload_cast_test; + namespace refl = tvm::ffi::reflection; + Pet p; + Cat cat; + Dog dog; + + // (a) Member with unique first arg per overload: spelling only the + // disambiguating prefix picks the overload and deduces the + // trailing `int amount` from the picked signature. + auto p_cat = refl::overload_cast<const Cat*>(&Pet::Feed); + static_assert(std::is_same_v<decltype(p_cat), int (Pet::*)(const Cat*, int)>, + "prefix match must deduce trailing arg types"); + EXPECT_EQ((p.*p_cat)(&cat, 7), 107); + + auto p_dog = refl::overload_cast<const Dog*>(&Pet::Feed); + EXPECT_EQ((p.*p_dog)(&dog, 12), 212); + + // (b) Free function with the same shape — trailing arg deduced. + auto p_free_cat = refl::overload_cast<const Cat*>(&FreeFeed); + EXPECT_EQ(p_free_cat(&cat, 7), 6007); + auto p_free_dog = refl::overload_cast<const Dog*>(&FreeFeed); + EXPECT_EQ(p_free_dog(&dog, 7), 7007); +} + +TEST(OverloadCast, AmbiguousPrefixRequiresMoreSpelling) { + using namespace overload_cast_test; + namespace refl = tvm::ffi::reflection; + Mix m; + + // Mix::Run has three overloads: + // Run(int, int), Run(int, double), Run(int, int, int) + // Spelling only <int> does not compile (all three start with int). + // Spelling the full parameter list of the wanted overload picks it. + EXPECT_EQ((m.*refl::overload_cast<int, int>(&Mix::Run))(0, 0), 1000); + EXPECT_EQ((m.*refl::overload_cast<int, double>(&Mix::Run))(0, 0.0), 2000); + EXPECT_EQ((m.*refl::overload_cast<int, int, int>(&Mix::Run))(0, 0, 0), 3000); +} + +TEST(OverloadCast, ConstQualifiedMember) { + using namespace overload_cast_test; + namespace refl = tvm::ffi::reflection; + Pet p; + const Pet& cp = p; + + // Non-const overload — no tag.
+ EXPECT_EQ((p.*refl::overload_cast<int>(&Pet::Get))(7), 4007); + + // Const overload — const_ tag required: the const-member candidates + // of overload_cast take a mandatory std::true_type tag, so without it + // only non-const members and free functions are considered. + EXPECT_EQ((cp.*refl::overload_cast<int>(&Pet::Get, refl::const_))(7), 5007); +} + +TEST(OverloadCast, NonTypeTemplateArgument) { + using namespace overload_cast_test; + namespace refl = tvm::ffi::reflection; + Pet p; + Mix m; + Cat cat; + + // Prefix match composed as a non-type template argument. + EXPECT_EQ((CallVia<refl::overload_cast<const Cat*>(&Pet::Feed)>::Run(p, &cat, 7)), 107); + + // Disambiguated 3-arg overload as a non-type template argument. + EXPECT_EQ((CallVia<refl::overload_cast<int, int, int>(&Mix::Run)>::Run(m, 0, 0, 0)), 3000); +} + } // namespace
