This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 84c5bdb [Feature] Support dynamic-style overload for FFI object types
(#286)
84c5bdb is described below
commit 84c5bdbcd11ae0e5bc921db4f35d3fd1d3ae762c
Author: DarkSharpness <[email protected]>
AuthorDate: Tue Dec 23 14:45:32 2025 +0800
[Feature] Support dynamic-style overload for FFI object types (#286)
Related discussion: #265.
Summary of changes:
1. Add `Function::FromPackedInplace` in `function.h` and generalize some
methods in `registry.h`
2. Implement method overloading for FFI object types (all in
`reflection/overload.h`)
3. Add a simple test.
---
include/tvm/ffi/function.h | 47 +++-
include/tvm/ffi/reflection/overload.h | 501 ++++++++++++++++++++++++++++++++++
include/tvm/ffi/reflection/registry.h | 43 +--
tests/cpp/test_overload.cc | 95 +++++++
4 files changed, 658 insertions(+), 28 deletions(-)
diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h
index d1cc693..3043731 100644
--- a/include/tvm/ffi/function.h
+++ b/include/tvm/ffi/function.h
@@ -33,6 +33,10 @@
#define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0
#endif
+#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
+#include <sstream>
+#endif // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
+
#include <tvm/ffi/any.h>
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
@@ -40,7 +44,9 @@
#include <tvm/ffi/function_details.h>
#include <functional>
+#include <optional>
#include <string>
+#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
@@ -165,21 +171,19 @@ class FunctionObjImpl : public FunctionObj {
/*!
* \brief Derived object class for constructing ffi::FunctionObj.
- * \param callable The type-erased callable object (rvalue).
- */
- explicit FunctionObjImpl(TCallable&& callable) :
callable_(std::move(callable)) {
- this->safe_call = SafeCall;
- this->cpp_call = reinterpret_cast<void*>(CppCall);
- }
- /*!
- * \brief Derived object class for constructing ffi::FunctionObj.
- * \param callable The type-erased callable object (lvalue).
+ * \param args The arguments to construct TCallable
*/
- explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) {
+ // In-place construction: every argument is forwarded to TCallable's constructor.
+ template <typename... Args>
+ explicit FunctionObjImpl(Args&&... args) :
callable_(std::forward<Args>(args)...) {
this->safe_call = SafeCall;
this->cpp_call = reinterpret_cast<void*>(CppCall);
}
+ // Non-copyable and non-assignable. NOTE(review): presumably so that the pointer
+ // returned by GetCallable() always refers to the single instance owned by its
+ // Function -- confirm.
+ FunctionObjImpl(const FunctionObjImpl&) = delete;
+ FunctionObjImpl& operator=(const FunctionObjImpl&) = delete;
+
+ /*! \brief Pointer to the stored callable; valid only while this object is alive. */
+ TCallable* GetCallable() { return &callable_; }
+
+
private:
// implementation of call
static void CppCall(const FunctionObj* func, const AnyView* args, int32_t
num_args, Any* result) {
@@ -356,6 +360,29 @@ class Function : public ObjectRef {
}
}
+  /*!
+   * \brief Construct a packed function whose callable is built in place.
+   *
+   * TCallable must be invocable as `(const AnyView* args, int32_t num_args, Any* rv)`,
+   * i.e. it uses the same calling convention as `ffi::Function`. The callable is
+   * constructed directly inside the function object from `args`. A raw pointer
+   * to it is returned alongside the Function so that the caller can keep
+   * interacting with the callable after construction (e.g. to register further
+   * overloads on it). The returned Function owns the callable: the pointer is
+   * valid only while some reference to that Function is alive.
+   *
+   * \tparam TCallable The callable type to construct (no reference/cv qualifiers).
+   * \param args The arguments forwarded to TCallable's constructor.
+   * \return A tuple of (Function, TCallable*).
+   */
+  template <typename TCallable, typename... Args>
+  static auto FromPackedInplace(Args&&... args) {
+    static_assert(std::is_same_v<TCallable, std::decay_t<TCallable>>,
+                  "FromPackedInplace: TCallable must be a value type (no reference or cv qualifiers)");
+    static_assert(std::is_invocable_v<TCallable, const AnyView*, int32_t, Any*>,
+                  "FromPackedInplace: TCallable must be invocable as (const AnyView*, int32_t, Any*)");
+    using ObjType = details::FunctionObjImpl<TCallable>;
+    auto obj_ptr = make_object<ObjType>(std::forward<Args>(args)...);
+    // Take the callable's address before ownership is moved into the Function.
+    TCallable* call_ptr = obj_ptr->GetCallable();
+    Function func;
+    func.data_ = std::move(obj_ptr);
+    return std::make_tuple(std::move(func), call_ptr);
+  }
+
/*!
* \brief Create ffi::Function from a C style callbacks.
*
diff --git a/include/tvm/ffi/reflection/overload.h
b/include/tvm/ffi/reflection/overload.h
new file mode 100644
index 0000000..6556338
--- /dev/null
+++ b/include/tvm/ffi/reflection/overload.h
@@ -0,0 +1,501 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection/overload.h
+ * \brief Reflection registration helpers that support overloaded methods on FFI object types.
+ */
+#ifndef TVM_FFI_EXTRA_OVERLOAD_H
+#define TVM_FFI_EXTRA_OVERLOAD_H
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <atomic>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+namespace details {
+
+/*!
+ * \brief Type-erased base of one overload candidate.
+ *
+ * Every candidate records how many arguments it accepts and an optional name
+ * used in error messages. Dispatchers additionally accept further candidates
+ * through Register().
+ */
+struct OverloadBase {
+ public:
+  /*!
+   * \brief Type-erased "try call" entry point.
+   * Returns true iff the arguments matched and the candidate was invoked.
+   */
+  using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*);
+
+  explicit OverloadBase(int32_t num_args, std::optional<std::string> name)
+      : num_args_(num_args),
+        name_(name ? std::move(*name) : ""),
+        name_ptr_(name ? &this->name_ : nullptr) {}
+
+  /*! \brief Append another candidate to this dispatcher. */
+  virtual void Register(std::unique_ptr<OverloadBase> overload) = 0;
+  /*! \brief Get a plain function pointer that tries to call this candidate. */
+  virtual FnPtr GetTryCallPtr() = 0;
+  /*! \brief Append to `os` a description of why `args` do not match this candidate. */
+  virtual void GetMismatchMessage(std::ostringstream& os, const AnyView* args,
+                                  int32_t num_args) = 0;
+
+  virtual ~OverloadBase() = default;
+  OverloadBase(const OverloadBase&) = delete;
+  OverloadBase& operator=(const OverloadBase&) = delete;
+
+ public:
+  static constexpr int32_t kAllMatched = -1;
+
+  // Diagnostic hint: index of the argument that failed to convert on a recent
+  // failed try (kAllMatched when none was recorded). An overload object is
+  // shared by every caller of the Function that owns it, so concurrent calls
+  // may write this at the same time; it is atomic to avoid a data race, and its
+  // value must only be treated as best-effort.
+  std::atomic<int32_t> last_mismatch_index_{kAllMatched};
+
+  // Number of arguments this candidate accepts.
+  const int32_t num_args_;
+  // Name used in error messages ("" when no name was given).
+  const std::string name_;
+  // Points at name_ when a name was given, nullptr otherwise.
+  const std::string* const name_ptr_;
+};
+
+/*!
+ * \brief Maps a packed argument tuple `std::tuple<A...>` to the storage tuple
+ *        `std::tuple<std::optional<std::decay_t<A>>...>` that holds converted
+ *        arguments until the call is made. Only the std::tuple form is defined.
+ */
+template <typename ArgTuple>
+struct CaptureTupleAux;
+
+template <typename... Params>
+struct CaptureTupleAux<std::tuple<Params...>> {
+  // One optional slot per parameter, holding the decayed parameter type.
+  template <typename P>
+  using Slot = std::optional<std::decay_t<P>>;
+  using type = std::tuple<Slot<Params>...>;
+};
+
+/*!
+ * \brief One overload candidate wrapping a typed C++ callable.
+ *
+ * TryCall converts every packed argument to the callable's declared parameter
+ * type. If all conversions succeed the callable is invoked and true is
+ * returned; otherwise false is returned and nothing is called.
+ *
+ * This object may be shared by concurrent callers, so neither TryCall nor the
+ * mismatch report writes any state on the object.
+ *
+ * \tparam Callable The callable type (must be a value type).
+ */
+template <typename Callable>
+struct TypedOverload : OverloadBase {
+ public:
+  static_assert(std::is_same_v<Callable, std::decay_t<Callable>>, "Callable must be value type");
+
+  using FuncInfo = details::FunctionInfo<Callable>;
+  using PackedArgs = typename FuncInfo::ArgType;
+  using Ret = typename FuncInfo::RetType;
+  using CaptureTuple = typename CaptureTupleAux<PackedArgs>::type;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using typename OverloadBase::FnPtr;
+
+  static constexpr auto kNumArgs = FuncInfo::num_args;
+  static constexpr auto kSeq = std::make_index_sequence<kNumArgs>{};
+
+  explicit TypedOverload(const Callable& f, std::optional<std::string> name = std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(f) {}
+  explicit TypedOverload(Callable&& f, std::optional<std::string> name = std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {}
+
+  /*!
+   * \brief Convert the arguments and call the wrapped callable if all of them match.
+   * \return true if the call was made; the result is stored in *rv unless Ret is void.
+   */
+  bool TryCall(const AnyView* args, int32_t num_args, Any* rv) {
+    if (num_args != kNumArgs) return false;
+    CaptureTuple captures{};
+    if (!TrySetAux(kSeq, captures, args)) return false;
+    // now all captures are set
+    if constexpr (std::is_same_v<Ret, void>) {
+      CallAux(kSeq, captures);
+    } else {
+      *rv = CallAux(kSeq, captures);
+    }
+    return true;
+  }
+
+  void Register(std::unique_ptr<OverloadBase> /*overload*/) override {
+    TVM_FFI_ICHECK(false) << "This should never be called.";
+  }
+
+  FnPtr GetTryCallPtr() final {
+    // lambda without a capture can be converted to function pointer
+    return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* rv) -> bool {
+      return static_cast<TypedOverload<Callable>*>(base)->TryCall(args, num_args, rv);
+    };
+  }
+
+  void GetMismatchMessage(std::ostringstream& os, const AnyView* args, int32_t num_args) final {
+    FGetFuncSignature f_sig = FuncInfo::Sig;
+    if (num_args != kNumArgs) {
+      os << "Mismatched number of arguments when calling: `" << name_ << " "
+         << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << kNumArgs << " arguments";
+    } else {
+      GetMismatchMessageAux<0>(os, args, num_args);
+    }
+  }
+
+ private:
+  /*!
+   * \brief Whether argument I converts to the callable's I-th parameter type.
+   * Uses the same acceptance rule as TrySetOne, but keeps no result.
+   */
+  template <std::size_t I>
+  static bool ArgMatches(const AnyView* args) {
+    using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+    if constexpr (std::is_same_v<Type, AnyView> || std::is_same_v<Type, Any>) {
+      return true;
+    } else {
+      return args[I].template try_cast<Type>().has_value();
+    }
+  }
+
+  // Report the first argument that fails to convert. The index is recomputed
+  // from `args` rather than cached during TryCall: this object is shared by all
+  // callers, so a cached index could describe another caller's arguments or be
+  // written concurrently. This only runs on the error path.
+  template <std::size_t I>
+  void GetMismatchMessageAux(std::ostringstream& os, const AnyView* args, int32_t num_args) {
+    if constexpr (I < kNumArgs) {
+      if (!ArgMatches<I>(args)) {
+        TVMFFIAny any_data = args[I].CopyToTVMFFIAny();
+        FGetFuncSignature f_sig = FuncInfo::Sig;
+        using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+        os << "Mismatched type on argument #" << I << " when calling: `" << name_ << " "
+           << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected `" << Type2Str<Type>::v()
+           << "` but got `" << TypeTraits<Type>::GetMismatchTypeInfo(&any_data) << '`';
+      } else {
+        GetMismatchMessageAux<I + 1>(os, args, num_args);
+      }
+    }
+    // end of recursion
+  }
+
+  template <std::size_t... I>
+  Ret CallAux(std::index_sequence<I...>, CaptureTuple& tuple) {
+    /// NOTE: this works for T, const T, const T&, T&& argument types
+    return f_(static_cast<std::tuple_element_t<I, PackedArgs>>(std::move(*std::get<I>(tuple)))...);
+  }
+
+  // Convert all arguments; stops at the first one that fails.
+  template <std::size_t... I>
+  bool TrySetAux(std::index_sequence<I...>, CaptureTuple& tuple, const AnyView* args) {
+    return (TrySetOne<I>(tuple, args) && ...);
+  }
+
+  template <std::size_t I>
+  bool TrySetOne(CaptureTuple& tuple, const AnyView* args) {
+    using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+    auto& capture = std::get<I>(tuple);
+    if constexpr (std::is_same_v<Type, AnyView>) {
+      capture = args[I];
+      return true;
+    } else if constexpr (std::is_same_v<Type, Any>) {
+      capture = Any(args[I]);
+      return true;
+    } else {
+      capture = args[I].template try_cast<Type>();
+      // A mismatch leaves no trace on this (shared) object.
+      return capture.has_value();
+    }
+  }
+
+ protected:
+  Callable f_;
+};
+
+/*!
+ * \brief Heap-allocate a TypedOverload candidate for `f`.
+ * \param f The callable; it is forwarded into the candidate.
+ * \param name Name shown in error messages for this candidate.
+ * \return A std::unique_ptr owning the new candidate.
+ */
+template <typename Callable>
+inline auto CreateNewOverload(Callable&& f, std::string name) {
+  using Candidate = TypedOverload<std::decay_t<Callable>>;
+  std::unique_ptr<Candidate> candidate =
+      std::make_unique<Candidate>(std::forward<Callable>(f), std::move(name));
+  return candidate;
+}
+
+/*!
+ * \brief Dispatcher for an overloaded method.
+ *
+ * It is itself the first candidate (a TypedOverload of `Callable`). Further
+ * candidates are appended with Register(). A call tries this candidate first,
+ * then every registered one in registration order, and invokes the first whose
+ * arguments all convert. If none matches, a TypeError describing every
+ * candidate's mismatch is thrown.
+ */
+template <typename Callable>
+struct OverloadedFunction : TypedOverload<Callable> {
+ public:
+  using TypedBase = TypedOverload<Callable>;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using TypedBase::GetTryCallPtr;
+  using TypedBase::kNumArgs;
+  using TypedBase::kSeq;
+  using TypedBase::TypedBase;  // constructors
+  using typename OverloadBase::FnPtr;
+  using typename TypedBase::Ret;
+
+  /*! \brief Append `overload` as the next candidate to try. */
+  void Register(std::unique_ptr<OverloadBase> overload) final {
+    // Resolve the entry point before ownership moves into the list.
+    FnPtr try_call = overload->GetTryCallPtr();
+    overloads_.emplace_back(std::move(overload), try_call);
+  }
+
+  /*! \brief Packed-call entry point. */
+  void operator()(const AnyView* args, int32_t num_args, Any* rv) {
+    if (overloads_.empty()) {
+      // Only one candidate: dispatch straight to the typed unpacker.
+      unpack_call<Ret>(kSeq, name_ptr_, f_, args, num_args, rv);
+      return;
+    }
+
+    if (this->TryCall(args, num_args, rv)) return;
+
+    for (const auto& entry : overloads_) {
+      OverloadBase* candidate = entry.first.get();
+      // Arity is checked inline so that mismatched candidates cost no indirect call.
+      if (candidate->num_args_ != num_args) continue;
+      if (entry.second(candidate, args, num_args, rv)) return;
+    }
+
+    this->HandleOverloadFailure(args, num_args);
+  }
+
+ private:
+  // Throw a TypeError that lists why each candidate rejected `args`.
+  void HandleOverloadFailure(const AnyView* args, int32_t num_args) {
+    std::ostringstream report;
+    int32_t index = 0;
+    report << "Overload #" << index++ << ": ";
+    this->GetMismatchMessage(report, args, num_args);
+    for (const auto& entry : overloads_) {
+      report << "\nOverload #" << index++ << ": ";
+      entry.first->GetMismatchMessage(report, args, num_args);
+    }
+    TVM_FFI_THROW(TypeError) << "No matching overload found when calling: `" << name_ << "` with "
+                             << num_args << " arguments:\n"
+                             << std::move(report).str();
+  }
+
+  using TypedBase::f_;
+  // Additional candidates, each paired with its pre-resolved try-call entry point.
+  std::vector<std::pair<std::unique_ptr<OverloadBase>, FnPtr>> overloads_;
+};
+
+} // namespace details
+
+/*! \brief Reflection namespace */
+namespace reflection {
+
+/*!
+ * \brief Helper to register an Object's reflection metadata with support for
+ *        overloaded methods.
+ *
+ * Behaves like ObjectDef, except for methods registered through this helper
+ * (def, def_static, def(init<...>)). Registering a method under a name that is
+ * already registered adds another overload to that method instead of defining
+ * a new one. At call time the overloads are tried in registration order, and
+ * the first whose arguments all convert is invoked. Fields cannot be
+ * overloaded, and static and non-static overloads cannot share a name.
+ *
+ * \tparam Class The class type.
+ *
+ * \code
+ * namespace refl = tvm::ffi::reflection;
+ * refl::OverloadObjectDef<MyClass>()
+ *     .def(refl::init<int32_t>())
+ *     .def(refl::init<float>())
+ *     .def_static("add_one", &MyClass::AddOneInt)
+ *     .def_static("add_one", &MyClass::AddOneFloat);
+ * \endcode
+ */
+template <typename Class>
+class OverloadObjectDef : private ObjectDef<Class> {
+ public:
+  using Super = ObjectDef<Class>;
+  /*!
+   * \brief Constructor
+   * \tparam ExtraArgs The extra arguments.
+   * \param extra_args The extra arguments, forwarded to ObjectDef.
+   */
+  template <typename... ExtraArgs>
+  explicit OverloadObjectDef(ExtraArgs&&... extra_args)
+      : Super(std::forward<ExtraArgs>(extra_args)...) {}
+
+  /*!
+   * \brief Define a readonly field. Fields cannot be overloaded.
+   *
+   * \tparam T The field type.
+   * \tparam BaseClass The class that declares the field.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return Reference to this OverloadObjectDef for method chaining.
+   */
+  template <typename T, typename BaseClass, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_ro(const char* name, T BaseClass::*field_ptr,
+                                           Extra&&... extra) {
+    /// NOTE: we don't allow properties to be overloaded
+    Super::def_ro(name, field_ptr, std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a read-write field. Fields cannot be overloaded.
+   *
+   * \tparam T The field type.
+   * \tparam BaseClass The class that declares the field.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the field.
+   * \param field_ptr The pointer to the field.
+   * \param extra The extra arguments that can be docstring or default value.
+   *
+   * \return Reference to this OverloadObjectDef for method chaining.
+   */
+  template <typename T, typename BaseClass, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_rw(const char* name, T BaseClass::*field_ptr,
+                                           Extra&&... extra) {
+    /// NOTE: we don't allow properties to be overloaded
+    Super::def_rw(name, field_ptr, std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a method, or add an overload if `name` is already defined.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring. They are only
+   *        applied when this call is the first registration of `name`.
+   *
+   * \return Reference to this OverloadObjectDef for method chaining.
+   */
+  template <typename Func, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def(const char* name, Func&& func, Extra&&... extra) {
+    RegisterMethod(name, false, std::forward<Func>(func), std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Define a static method, or add an overload if `name` is already defined.
+   *
+   * \tparam Func The function type.
+   * \tparam Extra The extra arguments.
+   *
+   * \param name The name of the method.
+   * \param func The function to be registered.
+   * \param extra The extra arguments that can be docstring. They are only
+   *        applied when this call is the first registration of `name`.
+   *
+   * \return Reference to this OverloadObjectDef for method chaining.
+   */
+  template <typename Func, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) {
+    RegisterMethod(name, true, std::forward<Func>(func), std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+  /*!
+   * \brief Register a constructor for this object type.
+   *
+   * Registers (or adds an overload to) the static init method (`kInitMethodName`),
+   * which constructs an instance of the object from the specified argument
+   * types. Calling this several times with different `init<...>` signatures
+   * yields an overloaded constructor.
+   *
+   * \tparam Args The argument types for the constructor.
+   * \tparam Extra Additional arguments (e.g., docstring).
+   *
+   * \param init_func An instance of `init<Args...>` specifying constructor signature.
+   * \param extra Optional additional metadata such as docstring; only applied
+   *        to the first constructor registered.
+   *
+   * \return Reference to this OverloadObjectDef for method chaining.
+   *
+   * Example:
+   * \code
+   * refl::OverloadObjectDef<MyObject>()
+   *     .def(refl::init<int32_t>())
+   *     .def(refl::init<float>());
+   * \endcode
+   */
+  template <typename... Args, typename... Extra>
+  TVM_FFI_INLINE OverloadObjectDef& def([[maybe_unused]] init<Args...> init_func,
+                                        Extra&&... extra) {
+    RegisterMethod(kInitMethodName, true, &init<Args...>::template execute<Class>,
+                   std::forward<Extra>(extra)...);
+    return *this;
+  }
+
+ private:
+  using ReflectionDefBase::WrapFunction;
+  using Super::kInitMethodName;
+  using Super::type_index_;
+  using Super::type_key_;
+
+  /*! \brief Bookkeeping for a method name registered through this helper. */
+  struct RegisteredMethod {
+    // Dispatcher stored inside the published Function. NOTE(review): assumes the
+    // type registry keeps that Function alive after registration -- confirm.
+    details::OverloadBase* overload;
+    // Whether the method was registered as a static method.
+    bool is_static;
+  };
+
+  // Build the dispatching Function for the first overload of a method. Also
+  // returns a pointer to the dispatcher so that later overloads can be appended.
+  template <typename Func>
+  static auto GetOverloadMethod(std::string name, Func&& func) {
+    using WrapFn = decltype(WrapFunction(std::forward<Func>(func)));
+    using OverloadFn = details::OverloadedFunction<std::decay_t<WrapFn>>;
+    return ffi::Function::FromPackedInplace<OverloadFn>(WrapFunction(std::forward<Func>(func)),
+                                                       std::move(name));
+  }
+
+  // Create a standalone candidate to append to an existing dispatcher.
+  template <typename Func>
+  static auto NewOverload(std::string name, Func&& func) {
+    return details::CreateNewOverload(WrapFunction(std::forward<Func>(func)), std::move(name));
+  }
+
+  /*!
+   * \brief Register `func` as method `name`, or append it as an overload.
+   *
+   * Only the first registration of a name publishes method info (flags, type
+   * schema, and `extra` traits). `extra` passed with later overloads is not
+   * applied. All overloads of one name must agree on `is_static`; otherwise a
+   * ValueError is thrown.
+   */
+  template <typename Func, typename... Extra>
+  void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) {
+    std::string method_name = std::string(type_key_) + "." + name;
+
+    // A method with this name exists already: append to its overload set.
+    if (auto it = registered_methods_.find(name); it != registered_methods_.end()) {
+      if (it->second.is_static != is_static) {
+        TVM_FFI_THROW(ValueError) << "Cannot overload method `" << method_name
+                                  << "`: static and non-static overloads cannot share a name";
+      }
+      it->second.overload->Register(NewOverload(std::move(method_name), std::forward<Func>(func)));
+      return;
+    }
+
+    // First registration of this name: build the dispatcher and publish it.
+    auto [method, overload_ptr] =
+        GetOverloadMethod(std::move(method_name), std::forward<Func>(func));
+    registered_methods_.try_emplace(name, RegisteredMethod{overload_ptr, is_static});
+
+    using FuncInfo = details::FunctionInfo<std::decay_t<Func>>;
+    MethodInfoBuilder info;
+    info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+    info.doc = TVMFFIByteArray{nullptr, 0};
+    info.flags = 0;
+    if (is_static) {
+      info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
+    }
+    info.method = AnyView(method).CopyToTVMFFIAny();
+    info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema());
+    // apply method info traits
+    ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+    std::string metadata_str = Metadata::ToJSON(info.metadata_);
+    info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()};
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
+  }
+
+  // Methods registered through this helper, keyed by method name.
+  std::unordered_map<std::string, RegisteredMethod> registered_methods_;
+};
+
+} // namespace reflection
+} // namespace ffi
+} // namespace tvm
+#endif // TVM_FFI_EXTRA_OVERLOAD_H
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 3224a9f..1dc22ae 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -28,6 +28,7 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/variant.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
#include <tvm/ffi/optional.h>
#include <tvm/ffi/string.h>
#include <tvm/ffi/type_traits.h>
@@ -36,6 +37,7 @@
#include <optional>
#include <sstream>
#include <string>
+#include <type_traits>
#include <utility>
#include <vector>
@@ -94,6 +96,8 @@ class Metadata : public InfoTrait {
friend class GlobalDef;
template <typename T>
friend class ObjectDef;
+ template <typename T>
+ friend class OverloadObjectDef;
/*!
* \brief Move metadata into a vector of key-value pairs.
* \param out The output vector.
@@ -270,52 +274,49 @@ class ReflectionDefBase {
}
}
+ /*!
+ * \brief Build a typed ffi::Function named `name` from `func`, after adapting
+ * member-function pointers into free callables via WrapFunction.
+ */
+ template <typename Func>
+ TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
+ return ffi::Function::FromTyped(WrapFunction(std::forward<Func>(func)),
std::move(name));
+ }
+
+ /*!
+ * \brief Identity adapter for callables that are not member-function pointers.
+ * Returns its argument as a forwarding reference, so the result may refer to
+ * the caller's temporary and must be consumed within the same full-expression.
+ */
+ template <typename Func>
+ TVM_FFI_INLINE static Func&& WrapFunction(Func&& func) {
+ return std::forward<Func>(func);
+ }
+ /*!
+ * \brief Adapt a non-const member-function pointer into a free callable.
+ * For ObjectRef classes the receiver is taken by value; for Object classes it
+ * is taken as `const Class*` and const is cast away to invoke the non-const method.
+ */
template <typename Class, typename R, typename... Args>
- TVM_FFI_INLINE static Function GetMethod(std::string name, R
(Class::*func)(Args...)) {
+ TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...)) {
static_assert(std::is_base_of_v<ObjectRef, Class> ||
std::is_base_of_v<Object, Class>,
"Class must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
- auto fwrap = [func](Class target, Args... params) -> R {
+ return [func](Class target, Args... params) -> R {
// call method pointer
return (target.*func)(std::forward<Args>(params)...);
};
- return ffi::Function::FromTyped(fwrap, std::move(name));
}
-
if constexpr (std::is_base_of_v<Object, Class>) {
- auto fwrap = [func](const Class* target, Args... params) -> R {
+ return [func](const Class* target, Args... params) -> R {
// call method pointer
return
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
};
- return ffi::Function::FromTyped(fwrap, std::move(name));
}
}
-
+ /*!
+ * \brief Adapt a const member-function pointer into a free callable.
+ * For ObjectRef classes the receiver is taken as `const Class&`; for Object
+ * classes it is taken as `const Class*`.
+ */
template <typename Class, typename R, typename... Args>
- TVM_FFI_INLINE static Function GetMethod(std::string name, R
(Class::*func)(Args...) const) {
+ TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...) const) {
static_assert(std::is_base_of_v<ObjectRef, Class> ||
std::is_base_of_v<Object, Class>,
"Class must be derived from ObjectRef or Object");
if constexpr (std::is_base_of_v<ObjectRef, Class>) {
- auto fwrap = [func](const Class& target, Args... params) -> R {
+ return [func](const Class& target, Args... params) -> R {
// call method pointer
return (target.*func)(std::forward<Args>(params)...);
};
- return ffi::Function::FromTyped(fwrap, std::move(name));
}
-
if constexpr (std::is_base_of_v<Object, Class>) {
- auto fwrap = [func](const Class* target, Args... params) -> R {
+ return [func](const Class* target, Args... params) -> R {
// call method pointer
return (target->*func)(std::forward<Args>(params)...);
};
- return ffi::Function::FromTyped(fwrap, std::move(name));
}
}
-
- template <typename Func>
- TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) {
- return ffi::Function::FromTyped(std::forward<Func>(func), std::move(name));
- }
};
/// \endcond
@@ -438,6 +439,8 @@ struct init {
// Allow ObjectDef to access the execute function
template <typename Class>
friend class ObjectDef;
+ template <typename T>
+ friend class OverloadObjectDef;
/*!
* \brief Constructor
@@ -585,6 +588,9 @@ class ObjectDef : public ReflectionDefBase {
}
private:
+ template <typename T>
+ friend class OverloadObjectDef;
+
template <typename... ExtraArgs>
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
TVMFFITypeMetadata info;
@@ -643,6 +649,7 @@ class ObjectDef : public ReflectionDefBase {
if (is_static) {
info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod;
}
+
// obtain the method function
Function method = GetMethod(std::string(type_key_) + "." + name,
std::forward<Func>(func));
info.method = AnyView(method).CopyToTVMFFIAny();
diff --git a/tests/cpp/test_overload.cc b/tests/cpp/test_overload.cc
new file mode 100644
index 0000000..7dfb9c7
--- /dev/null
+++ b/tests/cpp/test_overload.cc
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/object.h>
+#include <tvm/ffi/reflection/access_path.h>
+#include <tvm/ffi/reflection/accessor.h>
+#include <tvm/ffi/reflection/creator.h>
+#include <tvm/ffi/reflection/overload.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+struct TestOverloadObj : public Object {
+  // The constructor argument only selects the overload; its value is unused,
+  // so the parameters are unnamed (avoids -Wunused-parameter).
+  explicit TestOverloadObj(int32_t) : type(Type::INT) {}
+  explicit TestOverloadObj(float) : type(Type::FLOAT) {}
+
+  static int AddOneInt(int x) { return x + 1; }
+  static float AddOneFloat(float x) { return x + 1.0f; }
+
+  // Whether this object was constructed from a value of type T (int32_t or float).
+  template <typename T>
+  auto Holds(T) const {
+    if constexpr (std::is_same_v<T, int32_t>) {
+      return type == Type::INT;
+    } else if constexpr (std::is_same_v<T, float>) {
+      return type == Type::FLOAT;
+    } else {
+      static_assert(sizeof(T) == 0, "Unsupported type");
+    }
+  }
+
+  // Which constructor overload created this object.
+  enum class Type { INT, FLOAT } type;
+  TVM_FFI_DECLARE_OBJECT_INFO("test.TestOverloadObj", TestOverloadObj, Object);
+};
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ // Registration order matters: for each name, the overloads are tried in the
+ // order they are registered here, and the first candidate whose arguments all
+ // convert is called. The int32_t candidates are registered first, so an
+ // integer argument is tried against them before the float ones.
+ refl::OverloadObjectDef<TestOverloadObj>()
+ .def(refl::init<int32_t>())
+ .def(refl::init<float>())
+ .def("hold_same_type", &TestOverloadObj::Holds<int32_t>)
+ .def("hold_same_type", &TestOverloadObj::Holds<float>)
+ .def_static("add_one_static", &TestOverloadObj::AddOneInt)
+ .def_static("add_one_static", &TestOverloadObj::AddOneFloat);
+}
+
+// The overloaded constructor dispatches on the argument's type.
+TEST(Reflection, CallOverloadedInitMethod) {
+  Function init_method = reflection::GetMethod("test.TestOverloadObj", "__ffi_init__");
+
+  Any int_obj = init_method(10);  // int argument -> int32_t constructor
+  const auto* int_ptr = int_obj.as<TestOverloadObj>();
+  EXPECT_TRUE(int_ptr != nullptr);
+  EXPECT_EQ(int_ptr->type, TestOverloadObj::Type::INT);
+
+  Any float_obj = init_method(3.14f);  // float argument -> float constructor
+  const auto* float_ptr = float_obj.as<TestOverloadObj>();
+  EXPECT_TRUE(float_ptr != nullptr);
+  EXPECT_EQ(float_ptr->type, TestOverloadObj::Type::FLOAT);
+}
+
+// An overloaded instance method dispatches on its non-receiver argument.
+TEST(Reflection, CallOverloadedMethod) {
+  Function init_method = reflection::GetMethod("test.TestOverloadObj", "__ffi_init__");
+  Function hold_same_type = reflection::GetMethod("test.TestOverloadObj", "hold_same_type");
+  Any int_obj = init_method(10);  // int argument -> int32_t constructor
+
+  // int argument -> Holds<int32_t>, which matches the INT tag.
+  Any matches_int = hold_same_type(int_obj, 20);
+  EXPECT_EQ(matches_int.as<bool>(), true);
+
+  // float argument -> Holds<float>, which does not match the INT tag.
+  Any matches_float = hold_same_type(int_obj, 3.14f);
+  EXPECT_EQ(matches_float.as<bool>(), false);
+}
+
+// An overloaded static method picks the int or the float implementation.
+TEST(Reflection, CallOverloadedStaticMethod) {
+  Function add_one = reflection::GetMethod("test.TestOverloadObj", "add_one_static");
+  Any res_a = add_one(20);
+  EXPECT_EQ(res_a.as<int>(), 21);
+  Any res_b = add_one(1.0f);
+  // 1.0f + 1.0f is exactly representable, so an exact comparison is safe.
+  EXPECT_EQ(res_b.as<float>(), 2.0f);
+}
+
+// When no overload accepts the arguments, the call raises an error.
+TEST(Reflection, CallOverloadedMethodNoMatch) {
+  Function add_one = reflection::GetMethod("test.TestOverloadObj", "add_one_static");
+  // No overload accepts a string argument.
+  EXPECT_THROW(static_cast<void>(add_one("not a number")), Error);
+  // No overload accepts two arguments.
+  EXPECT_THROW(static_cast<void>(add_one(1, 2)), Error);
+}
+
+} // namespace