This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 84c5bdb  [Feature] Support dynamic-style overload for FFI object types 
(#286)
84c5bdb is described below

commit 84c5bdbcd11ae0e5bc921db4f35d3fd1d3ae762c
Author: DarkSharpness <[email protected]>
AuthorDate: Tue Dec 23 14:45:32 2025 +0800

    [Feature] Support dynamic-style overload for FFI object types (#286)
    
    Related discussion here #265 .
    
    Modification in short:
    
    1. Add `Function::FromPackedInplace` in `function.h` and generalize some
    methods in `registry.h`
    2. Implement FFI object types overload method (all in
    `reflection/overload.h`)
    3. Add a simple test.
---
 include/tvm/ffi/function.h            |  47 +++-
 include/tvm/ffi/reflection/overload.h | 501 ++++++++++++++++++++++++++++++++++
 include/tvm/ffi/reflection/registry.h |  43 +--
 tests/cpp/test_overload.cc            |  95 +++++++
 4 files changed, 658 insertions(+), 28 deletions(-)

diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index d1cc693..3043731 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -33,6 +33,10 @@
 #define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0
 #endif
 
+#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
+#include <sstream>
+#endif  // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
+
 #include <tvm/ffi/any.h>
 #include <tvm/ffi/base_details.h>
 #include <tvm/ffi/c_api.h>
@@ -40,7 +44,9 @@
 #include <tvm/ffi/function_details.h>
 
 #include <functional>
+#include <optional>
 #include <string>
+#include <tuple>
 #include <type_traits>
 #include <utility>
 #include <vector>
@@ -165,21 +171,19 @@ class FunctionObjImpl : public FunctionObj {
 
   /*!
    * \brief Derived object class for constructing ffi::FunctionObj.
-   * \param callable The type-erased callable object (rvalue).
-   */
-  explicit FunctionObjImpl(TCallable&& callable) : 
callable_(std::move(callable)) {
-    this->safe_call = SafeCall;
-    this->cpp_call = reinterpret_cast<void*>(CppCall);
-  }
-  /*!
-   * \brief Derived object class for constructing ffi::FunctionObj.
-   * \param callable The type-erased callable object (lvalue).
+   * \param args The arguments to construct TCallable
    */
-  explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) {
+  template <typename... Args>
+  explicit FunctionObjImpl(Args&&... args) : 
callable_(std::forward<Args>(args)...) {
     this->safe_call = SafeCall;
     this->cpp_call = reinterpret_cast<void*>(CppCall);
   }
 
+  FunctionObjImpl(const FunctionObjImpl&) = delete;
+  FunctionObjImpl& operator=(const FunctionObjImpl&) = delete;
+
+  TCallable* GetCallable() { return &callable_; }
+
  private:
   // implementation of call
   static void CppCall(const FunctionObj* func, const AnyView* args, int32_t 
num_args, Any* result) {
@@ -356,6 +360,29 @@ class Function : public ObjectRef {
     }
   }
 
+  /*!
+   * \brief Constructing a packed function from a callable type
+   *        whose signature is consistent with `ffi::Function`.
+   *        It will create the Callable object with the given arguments,
+   *        and return the inplace constructed Function along with
+   *        the pointer to the callable object. The lifetime of the callable
+   *        object is managed by the returned Function.
+   * \param args The arguments to construct TCallable
+   * \return A tuple of (Function, TCallable*)
+   */
+  template <typename TCallable, typename... Args>
+  static auto FromPackedInplace(Args&&... args) {
+    // We must ensure TCallable is a value type (decay_t) that can hold the 
callable object
+    static_assert(std::is_same_v<TCallable, std::decay_t<TCallable>>);
+    static_assert(std::is_invocable_v<TCallable, const AnyView*, int32_t, 
Any*>);
+    using ObjType = details::FunctionObjImpl<TCallable>;
+    Function func;
+    auto obj_ptr = make_object<ObjType>(std::forward<Args>(args)...);
+    auto* call_ptr = obj_ptr->GetCallable();
+    func.data_ = std::move(obj_ptr);
+    return std::make_tuple(std::move(func), call_ptr);
+  }
+
   /*!
    * \brief Create ffi::Function from a C style callbacks.
    *
diff --git a/include/tvm/ffi/reflection/overload.h 
b/include/tvm/ffi/reflection/overload.h
new file mode 100644
index 0000000..6556338
--- /dev/null
+++ b/include/tvm/ffi/reflection/overload.h
@@ -0,0 +1,501 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection/overload.h
+ * \brief Registry of reflection metadata, supporting function overloading.
+ */
+#ifndef TVM_FFI_EXTRA_OVERLOAD_H
+#define TVM_FFI_EXTRA_OVERLOAD_H
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+namespace details {
+
+struct OverloadBase {
+ public:
+  // Signature of the devirtualized try-call entry: returns true iff args matched and it was called
+  using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*);
+
+  explicit OverloadBase(int32_t num_args, std::optional<std::string> name)
+      : num_args_(num_args),
+        name_(name ? std::move(*name) : ""),
+        name_ptr_(name ? &this->name_ : nullptr) {}
+
+  virtual void Register(std::unique_ptr<OverloadBase> overload) = 0;
+  virtual FnPtr GetTryCallPtr() = 0;
+  virtual void GetMismatchMessage(std::ostringstream& os, const AnyView* args,
+                                  int32_t num_args) = 0;
+
+  virtual ~OverloadBase() = default;
+  OverloadBase(const OverloadBase&) = delete;
+  OverloadBase& operator=(const OverloadBase&) = delete;
+
+ public:
+  static constexpr int32_t kAllMatched = -1;
+
+  // index of the first argument that failed to convert in the latest failed try-call
+  // (kAllMatched until one is recorded); the original note says it packs with num_args_ on 64-bit
+  int32_t last_mismatch_index_{kAllMatched};
+
+  // constant per-overload data: expected arity, display name, and name pointer (null if unnamed)
+  const int32_t num_args_;
+  const std::string name_;
+  const std::string* const name_ptr_;
+};
+
+template <typename T>
+struct CaptureTupleAux;
+
+template <typename... Args>
+struct CaptureTupleAux<std::tuple<Args...>> {
+  using type = std::tuple<std::optional<std::decay_t<Args>>...>;
+};
+
+template <typename Callable>
+struct TypedOverload : OverloadBase {
+ public:
+  static_assert(std::is_same_v<Callable, std::decay_t<Callable>>, "Callable 
must be value type");
+
+  using FuncInfo = details::FunctionInfo<Callable>;
+  using PackedArgs = typename FuncInfo::ArgType;
+  using Ret = typename FuncInfo::RetType;
+  using CaptureTuple = typename CaptureTupleAux<PackedArgs>::type;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using typename OverloadBase::FnPtr;
+
+  static constexpr auto kNumArgs = FuncInfo::num_args;
+  static constexpr auto kSeq = std::make_index_sequence<kNumArgs>{};
+
+  explicit TypedOverload(const Callable& f, std::optional<std::string> name = 
std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(f) {}
+  explicit TypedOverload(Callable&& f, std::optional<std::string> name = 
std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {}
+
+  bool TryCall(const AnyView* args, int32_t num_args, Any* rv) {
+    if (num_args != kNumArgs) return false;
+    CaptureTuple captures{};
+    if (!TrySetAux(kSeq, captures, args)) return false;
+    // now all captures are set
+    if constexpr (std::is_same_v<Ret, void>) {
+      CallAux(kSeq, captures);
+      return true;
+    } else {
+      *rv = CallAux(kSeq, captures);
+      return true;
+    }
+  }
+
+  void Register(std::unique_ptr<OverloadBase> overload) override {
+    TVM_FFI_ICHECK(false) << "This should never be called.";
+  }
+
+  FnPtr GetTryCallPtr() final {
+    // lambda without a capture can be converted to function pointer
+    return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* 
rv) -> bool {
+      return static_cast<TypedOverload<Callable>*>(base)->TryCall(args, 
num_args, rv);
+    };
+  }
+
+  void GetMismatchMessage(std::ostringstream& os, const AnyView* args, int32_t 
num_args) final {
+    FGetFuncSignature f_sig = FuncInfo::Sig;
+    if (num_args != kNumArgs) {
+      os << "Mismatched number of arguments when calling: `" << name_ << " "
+         << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << kNumArgs 
<< " arguments";
+    } else {
+      GetMismatchMessageAux<0>(os, args, num_args);
+    }
+  }
+
+ private:
+  template <std::size_t I>
+  void GetMismatchMessageAux(std::ostringstream& os, const AnyView* args, 
int32_t num_args) {
+    if constexpr (I < kNumArgs) {
+      if (this->last_mismatch_index_ == static_cast<int32_t>(I)) {
+        TVMFFIAny any_data = args[I].CopyToTVMFFIAny();
+        FGetFuncSignature f_sig = FuncInfo::Sig;
+        using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+        os << "Mismatched type on argument #" << I << " when calling: `" << 
name_ << " "
+           << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected `" << 
Type2Str<Type>::v()
+           << "` but got `" << 
TypeTraits<Type>::GetMismatchTypeInfo(&any_data) << '`';
+      } else {
+        GetMismatchMessageAux<I + 1>(os, args, num_args);
+      }
+    }
+    // end of recursion
+  }
+
+  template <std::size_t... I>
+  Ret CallAux(std::index_sequence<I...>, CaptureTuple& tuple) {
+    /// NOTE: this works for T, const T, const T&, T&& argument types
+    return f_(static_cast<std::tuple_element_t<I, 
PackedArgs>>(std::move(*std::get<I>(tuple)))...);
+  }
+
+  template <std::size_t... I>
+  bool TrySetAux(std::index_sequence<I...>, CaptureTuple& tuple, const 
AnyView* args) {
+    return (TrySetOne<I>(tuple, args) && ...);
+  }
+
+  template <std::size_t I>
+  bool TrySetOne(CaptureTuple& tuple, const AnyView* args) {
+    using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+    auto& capture = std::get<I>(tuple);
+    if constexpr (std::is_same_v<Type, AnyView>) {
+      capture = args[I];
+      return true;
+    } else if constexpr (std::is_same_v<Type, Any>) {
+      capture = Any(args[I]);
+      return true;
+    } else {
+      capture = args[I].template try_cast<Type>();
+      if (capture.has_value()) return true;
+      // slow path: record the last mismatch index
+      this->last_mismatch_index_ = static_cast<int32_t>(I);
+      return false;
+    }
+  }
+
+ protected:
+  Callable f_;
+};
+
+template <typename Callable>
+inline auto CreateNewOverload(Callable&& f, std::string name) {
+  using Type = TypedOverload<std::decay_t<Callable>>;
+  return std::make_unique<Type>(std::forward<Callable>(f), std::move(name));
+}
+
+template <typename Callable>
+struct OverloadedFunction : TypedOverload<Callable> {
+ public:
+  using TypedBase = TypedOverload<Callable>;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using TypedBase::GetTryCallPtr;
+  using TypedBase::kNumArgs;
+  using TypedBase::kSeq;
+  using TypedBase::TypedBase;  // constructors
+  using typename OverloadBase::FnPtr;
+  using typename TypedBase::Ret;
+
+  void Register(std::unique_ptr<OverloadBase> overload) final {
+    const auto fptr = overload->GetTryCallPtr();
+    overloads_.emplace_back(std::move(overload), fptr);
+  }
+
+  void operator()(const AnyView* args, int32_t num_args, Any* rv) {
+    // fast path: only add a little overhead when no overloads
+    if (overloads_.size() == 0) {
+      return unpack_call<Ret>(kSeq, name_ptr_, f_, args, num_args, rv);
+    }
+
+    // this can be inlined by compiler, don't worry
+    if (this->TryCall(args, num_args, rv)) return;
+
+    // virtual calls cannot be inlined, so we fast check the num_args first
+    // we also de-virtualize the fptr to reduce one more indirection
+    for (const auto& [overload, fptr] : overloads_) {
+      if (overload->num_args_ != num_args) continue;
+      if (fptr(overload.get(), args, num_args, rv)) return;
+    }
+
+    this->HandleOverloadFailure(args, num_args);
+  }
+
+ private:
+  void HandleOverloadFailure(const AnyView* args, int32_t num_args) {
+    std::ostringstream oss;
+    int32_t i = 0;
+    oss << "Overload #" << i++ << ": ";
+    this->GetMismatchMessage(oss, args, num_args);
+    for (const auto& [overload, _] : overloads_) {
+      oss << "\nOverload #" << i++ << ": ";
+      overload->GetMismatchMessage(oss, args, num_args);
+    }
+    TVM_FFI_THROW(TypeError) << "No matching overload found when calling: `" 
<< name_ << "` with "
+                             << num_args << " arguments:\n"
+                             << std::move(oss).str();
+  }
+  using TypedBase::f_;
+  std::vector<std::pair<std::unique_ptr<OverloadBase>, FnPtr>> overloads_;
+};
+
+}  // namespace details
+
+/*! \brief Reflection namespace */
+namespace reflection {
+
+/*!
+ * \brief Helper to register Object's reflection metadata, allowing same-name method overloads.
+ * \tparam Class The class type.
+ *
+ * \code
+ *  namespace refl = tvm::ffi::reflection;
+ *  refl::OverloadObjectDef<MyClass>().def_ro("my_field", &MyClass::my_field);
+ * \endcode
+ */
+template <typename Class>
+class OverloadObjectDef : private ObjectDef<Class> {
+ public:
+  using Super = ObjectDef<Class>;
+  /*!
+   * \brief Constructor
+   * \tparam ExtraArgs The extra arguments.
+   * \param extra_args The extra arguments.
+   */
+  template <typename... ExtraArgs>
+  explicit OverloadObjectDef(ExtraArgs&&... extra_args)
+      : Super(std::forward<ExtraArgs>(extra_args)...) {}
+
+  /*!
+   * \brief Define a readonly field.
+   *
+   * \tparam T The field type.
+   * \tparam BaseClass The class that declares the field.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return The reflection definition.
+   */
+  template <typename T, typename BaseClass, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_ro(const char* name, T BaseClass::* 
field_ptr,
+                                           Extra&&... extra) {
+    /// NOTE: we don't allow properties to be overloaded
+    Super::def_ro(name, field_ptr, std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a read-write field.
+   *
+   * \tparam T The field type.
+   * \tparam BaseClass The class that declares the field.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return The reflection definition.
+   */
+  template <typename T, typename BaseClass, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_rw(const char* name, T BaseClass::* 
field_ptr,
+                                           Extra&&... extra) {
+    /// NOTE: we don't allow properties to be overloaded
+    Super::def_rw(name, field_ptr, std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a method.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def(const char* name, Func&& func, 
Extra&&... extra) {
+    RegisterMethod(name, false, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a static method.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring.
+   *
+   * \return The reflection definition.
+   */
+  template <typename Func, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_static(const char* name, Func&& func, 
Extra&&... extra) {
+    RegisterMethod(name, true, std::forward<Func>(func), 
std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Register a constructor for this object type.
+   *
+   * This method registers a static `__init__` method that constructs an 
instance
+   * of the object with the specified argument types. The constructor can be 
invoked
+   * from Python or other FFI bindings.
+   *
+   * \tparam Args The argument types for the constructor.
+   * \tparam Extra Additional arguments (e.g., docstring).
+   *
+   * \param init_func An instance of `init<Args...>` specifying constructor 
signature.
+   * \param extra Optional additional metadata such as docstring.
+   *
+   * \return Reference to this `OverloadObjectDef` for method chaining.
+   *
+   * Example:
+   * \code
+   *   refl::OverloadObjectDef<MyObject>()
+   *       .def(refl::init<int64_t, std::string>(), "Constructor docstring");
+   * \endcode
+   */
+  template <typename... Args, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def([[maybe_unused]] init<Args...> 
init_func,
+                                        Extra&&... extra) {
+    RegisterMethod(kInitMethodName, true, &init<Args...>::template 
execute<Class>,
+                   std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+ private:
+  using ReflectionDefBase::ApplyExtraInfoTrait;
+  using ReflectionDefBase::WrapFunction;
+  using Super::kInitMethodName;
+  using Super::type_index_;
+  using Super::type_key_;
+
+  template <typename Func>
+  static auto GetOverloadMethod(std::string name, Func&& func) {
+    using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
+    using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
+    return 
ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
+                                                        std::move(name));
+  }
+
+  template <typename Func>
+  static auto NewOverload(std::string name, Func&& func) {
+    return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), 
std::move(name));
+  }
+
+  template <typename... ExtraArgs>
+  void RegisterExtraInfo(ExtraArgs&&... extra_args) {
+    TVMFFITypeMetadata info;
+    info.total_size = sizeof(Class);
+    info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind;
+    info.creator = nullptr;
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    if constexpr (std::is_default_constructible_v<Class>) {
+      info.creator = ReflectionDefBase::ObjectCreatorDefault<Class>;
+    } else if constexpr (std::is_constructible_v<Class, UnsafeInit>) {
+      info.creator = ReflectionDefBase::ObjectCreatorUnsafeInit<Class>;
+    }
+    // apply extra info traits
+    ((ApplyExtraInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info));
+  }
+
+  template <typename T, typename BaseClass, typename... ExtraArgs>
+  void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable,
+                     ExtraArgs&&... extra_args) {
+    static_assert(std::is_base_of_v<BaseClass, Class>, "BaseClass must be a 
base class of Class");
+    FieldInfoBuilder info;
+    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+    info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
+    // store byte offset and setter, getter
+    // so the same setter can be reused for all the same type
+    info.offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
+    info.size = sizeof(T);
+    info.alignment = alignof(T);
+    info.flags = 0;
+    if (writable) {
+      info.flags |= kTVMFFIFieldFlagBitMaskWritable;
+    }
+    info.getter = ReflectionDefBase::FieldGetter<T>;
+    info.setter = ReflectionDefBase::FieldSetter<T>;
+    // initialize default value to nullptr
+    info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.metadata_.emplace_back("type_schema", details::TypeSchema<T>::v());
+    // apply field info traits
+    ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
+    // call register
+    std::string metadata_str = Metadata::ToJSON(info.metadata_);
+    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
+  }
+
+  // register a method
+  template <typename Func, typename... Extra>
+  void RegisterMethod(const char* name, bool is_static, Func&& func, 
Extra&&... extra) {
+    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    MethodInfoBuilder info;
+    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.flags = 0;
+    if (is_static) {
+      info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
+    }
+
+    auto method_name = std::string(type_key_) + "." + name;
+
+    // if an overload method exists, register to existing overload function
+    if (const auto overload_it = registered_fields_.find(name);
+        overload_it != registered_fields_.end()) {
+      details::OverloadBase* overload_ptr = overload_it->second;
+      return overload_ptr->Register(NewOverload(std::move(method_name), 
std::forward<Func>(func)));
+    }
+
+    // first time registering overload method
+    auto [method, overload_ptr] =
+        GetOverloadMethod(std::move(method_name), std::forward<Func>(func));
+    registered_fields_.try_emplace(name, overload_ptr);
+
+    info.method = AnyView(method).CopyToTVMFFIAny();
+    info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema());
+    // apply method info traits
+    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+    std::string metadata_str = Metadata::ToJSON(info.metadata_);
+    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
+  }
+
+  std::unordered_map<std::string, details::OverloadBase*> registered_fields_;
+};
+
+}  // namespace reflection
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_EXTRA_OVERLOAD_H
diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index 3224a9f..1dc22ae 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -28,6 +28,7 @@
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/container/variant.h>
 #include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
 #include <tvm/ffi/optional.h>
 #include <tvm/ffi/string.h>
 #include <tvm/ffi/type_traits.h>
@@ -36,6 +37,7 @@
 #include <optional>
 #include <sstream>
 #include <string>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -94,6 +96,8 @@ class Metadata : public InfoTrait {
   friend class GlobalDef;
   template <typename T>
   friend class ObjectDef;
+  template <typename T>
+  friend class OverloadObjectDef;
   /*!
    * \brief Move metadata into a vector of key-value pairs.
    * \param out The output vector.
@@ -270,52 +274,49 @@ class ReflectionDefBase {
     }
   }
 
+  template <typename Func>
+  TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
+    return ffi::Function::FromTyped(WrapFunction(std::forward<Func>(func)), 
std::move(name));
+  }
+
+  template <typename Func>
+  TVM_FFI_INLINE static Func&& WrapFunction(Func&& func) {
+    return std::forward<Func>(func);
+  }
   template <typename Class, typename R, typename... Args>
-  TVM_FFI_INLINE static Function GetMethod(std::string name, R 
(Class::*func)(Args...)) {
+  TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...)) {
     static_assert(std::is_base_of_v<ObjectRef, Class> || 
std::is_base_of_v<Object, Class>,
                   "Class must be derived from ObjectRef or Object");
     if constexpr (std::is_base_of_v<ObjectRef, Class>) {
-      auto fwrap = [func](Class target, Args... params) -> R {
+      return [func](Class target, Args... params) -> R {
         // call method pointer
         return (target.*func)(std::forward<Args>(params)...);
       };
-      return ffi::Function::FromTyped(fwrap, std::move(name));
     }
-
     if constexpr (std::is_base_of_v<Object, Class>) {
-      auto fwrap = [func](const Class* target, Args... params) -> R {
+      return [func](const Class* target, Args... params) -> R {
         // call method pointer
         return 
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
       };
-      return ffi::Function::FromTyped(fwrap, std::move(name));
     }
   }
-
   template <typename Class, typename R, typename... Args>
-  TVM_FFI_INLINE static Function GetMethod(std::string name, R 
(Class::*func)(Args...) const) {
+  TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...) const) {
     static_assert(std::is_base_of_v<ObjectRef, Class> || 
std::is_base_of_v<Object, Class>,
                   "Class must be derived from ObjectRef or Object");
     if constexpr (std::is_base_of_v<ObjectRef, Class>) {
-      auto fwrap = [func](const Class& target, Args... params) -> R {
+      return [func](const Class& target, Args... params) -> R {
         // call method pointer
         return (target.*func)(std::forward<Args>(params)...);
       };
-      return ffi::Function::FromTyped(fwrap, std::move(name));
     }
-
     if constexpr (std::is_base_of_v<Object, Class>) {
-      auto fwrap = [func](const Class* target, Args... params) -> R {
+      return [func](const Class* target, Args... params) -> R {
         // call method pointer
         return (target->*func)(std::forward<Args>(params)...);
       };
-      return ffi::Function::FromTyped(fwrap, std::move(name));
     }
   }
-
-  template <typename Func>
-  TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
-    return ffi::Function::FromTyped(std::forward<Func>(func), std::move(name));
-  }
 };
 /// \endcond
 
@@ -438,6 +439,8 @@ struct init {
   // Allow ObjectDef to access the execute function
   template <typename Class>
   friend class ObjectDef;
+  template <typename T>
+  friend class OverloadObjectDef;
 
   /*!
    * \brief Constructor
@@ -585,6 +588,9 @@ class ObjectDef : public ReflectionDefBase {
   }
 
  private:
+  template <typename T>
+  friend class OverloadObjectDef;
+
   template <typename... ExtraArgs>
   void RegisterExtraInfo(ExtraArgs&&... extra_args) {
     TVMFFITypeMetadata info;
@@ -643,6 +649,7 @@ class ObjectDef : public ReflectionDefBase {
     if (is_static) {
       info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
     }
+
     // obtain the method function
     Function method = GetMethod(std::string(type_key_) + "." + name, 
std::forward<Func>(func));
     info.method = AnyView(method).CopyToTVMFFIAny();
diff --git a/tests/cpp/test_overload.cc b/tests/cpp/test_overload.cc
new file mode 100644
index 0000000..7dfb9c7
--- /dev/null
+++ b/tests/cpp/test_overload.cc
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/access_path.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/creator.h>
+#include <tvm/ffi/reflection/overload.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+struct TestOverloadObj : public Object {
+  explicit TestOverloadObj(int32_t x) : type(Type::INT) {}
+  explicit TestOverloadObj(float y) : type(Type::FLOAT) {}
+
+  static int AddOneInt(int x) { return x + 1; }
+  static float AddOneFloat(float x) { return x + 1.0f; }
+
+  template <typename T>
+  auto Holds(T) const {
+    if constexpr (std::is_same_v<T, int32_t>) {
+      return type == Type::INT;
+    } else if constexpr (std::is_same_v<T, float>) {
+      return type == Type::FLOAT;
+    } else {
+      static_assert(sizeof(T) == 0, "Unsupported type");
+    }
+  }
+
+  enum class Type { INT, FLOAT } type;
+  TVM_FFI_DECLARE_OBJECT_INFO("test.TestOverloadObj", TestOverloadObj, Object);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::OverloadObjectDef<TestOverloadObj>()
+      .def(refl::init<int32_t>())
+      .def(refl::init<float>())
+      .def("hold_same_type", &TestOverloadObj::Holds<int32_t>)
+      .def("hold_same_type", &TestOverloadObj::Holds<float>)
+      .def_static("add_one_static", &TestOverloadObj::AddOneInt)
+      .def_static("add_one_static", &TestOverloadObj::AddOneFloat);
+}
+
+TEST(Reflection, CallOverloadedInitMethod) {
+  Function init_method = reflection::GetMethod("test.TestOverloadObj", 
"__ffi_init__");
+  Any obj_a = init_method(10);  // choose the int constructor
+  EXPECT_TRUE(obj_a.as<TestOverloadObj>() != nullptr);
+  EXPECT_EQ(obj_a.as<TestOverloadObj>()->type, TestOverloadObj::Type::INT);
+  Any obj_b = init_method(3.14f);  // choose the float constructor
+  EXPECT_TRUE(obj_b.as<TestOverloadObj>() != nullptr);
+  EXPECT_EQ(obj_b.as<TestOverloadObj>()->type, TestOverloadObj::Type::FLOAT);
+}
+
+TEST(Reflection, CallOverloadedMethod) {
+  Function init_method = reflection::GetMethod("test.TestOverloadObj", 
"__ffi_init__");
+  Function hold_same_type = reflection::GetMethod("test.TestOverloadObj", 
"hold_same_type");
+  Any obj_a = init_method(10);  // choose the int constructor
+  Any res_a = hold_same_type(obj_a, 20);
+  EXPECT_EQ(res_a.as<bool>(), true);
+  Any res_b = hold_same_type(obj_a, 3.14f);
+  EXPECT_EQ(res_b.as<bool>(), false);
+}
+
+TEST(Reflection, CallOverloadedStaticMethod) {
+  Function add_one = reflection::GetMethod("test.TestOverloadObj", 
"add_one_static");
+  Any res_a = add_one(20);
+  EXPECT_EQ(res_a.as<int>(), 21);
+  Any res_b = add_one(1.0f);
+  static_assert(1.0f + 1.0f == 2.0f);
+  EXPECT_EQ(res_b.as<float>(), 2.0f);
+}
+
+}  // namespace

Reply via email to