This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0f5bb200f1 [FFI][REFACTOR] Enhance reflection (#18059)
0f5bb200f1 is described below
commit 0f5bb200f129ec4c0a235e065f452ca524d19bd8
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 15 07:46:42 2025 -0400
[FFI][REFACTOR] Enhance reflection (#18059)
This PR enhances and refactors the reflection module to support
nanobind/pybind-style reflection definitions.
This is one step toward upgrading the overall reflection
mechanism in the project.
---
ffi/include/tvm/ffi/c_api.h | 100 ++++++++---
ffi/include/tvm/ffi/function.h | 8 +-
ffi/include/tvm/ffi/object.h | 8 +-
ffi/include/tvm/ffi/reflection/reflection.h | 265 +++++++++++++++++++++++++---
ffi/include/tvm/ffi/string.h | 2 -
ffi/include/tvm/ffi/type_traits.h | 18 +-
ffi/src/ffi/object.cc | 48 ++++-
ffi/tests/cpp/test_any.cc | 3 +
ffi/tests/cpp/test_array.cc | 1 +
ffi/tests/cpp/test_reflection.cc | 89 +++++++++-
ffi/tests/cpp/test_tuple.cc | 1 +
ffi/tests/cpp/test_variant.cc | 1 +
ffi/tests/cpp/testing_object.h | 10 +-
13 files changed, 475 insertions(+), 79 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 7cf7543f48..a43f2e83b5 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -303,12 +303,54 @@ typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny*
result);
*/
typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value);
+/*!
+ * \brief Bitmask flags that describe a reflected field or method.
+ */
+#ifdef __cplusplus
+enum TVMFFIFieldFlagBitMask : int32_t {
+#else
+typedef enum {
+#endif
+ /*! \brief The field is writable. */
+ TVMFFIFieldFlagBitMaskWritable = 1 << 0,
+ /*! \brief The field has default value. */
+ TVMFFIFieldFlagBitMaskHasDefault = 1 << 1,
+ /*! \brief The field is a static method. */
+ TVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2,
+#ifdef __cplusplus
+};
+#else
+} TVMFFIFieldFlagBitMask;
+#endif
+
/*!
* \brief Information support for optional object reflection.
*/
typedef struct {
/*! \brief The name of the field. */
TVMFFIByteArray name;
+ /*! \brief The docstring about the field. */
+ TVMFFIByteArray doc;
+ /*!
+ * \brief bitmask flags of the field.
+ */
+ int64_t flags;
+ /*!
+ * \brief Byte offset of the field.
+ */
+ int64_t byte_offset;
+ /*! \brief The getter to access the field. */
+ TVMFFIFieldGetter getter;
+ /*!
+ * \brief The setter to access the field.
+ * \note The setter is set even if the field is read-only, so that it can be used for serialization.
+ */
+ TVMFFIFieldSetter setter;
+ /*!
+ * \brief The default value of the field, stored as an AnyView.
+ * Only valid when flags has TVMFFIFieldFlagBitMaskHasDefault set.
+ */
+ TVMFFIAny default_value;
/*!
* \brief Records the static type kind of the field.
*
@@ -317,28 +359,18 @@ typedef struct {
* - TVMFFITypeIndex::kTVMFFIObject for general objects
* - The value is nullable when kTVMFFIObject is chosen
* - static object type kinds such as Map, Dict, String
- * - POD type index
+ * - POD type index; note that it does not give information about the storage size of the field.
* - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info
* about the field.
*
- * \note This information is helpful in designing serializer
- * of the field. As it helps to narrow down the type of the
- * object. It also helps to provide opportunities to enable
- * short-cut access to the field.
+ * When the value is a type index of an Object type, the field is stored as an ObjectRef.
+ *
+ * \note This information may be helpful in designing serializers,
+ * as it helps to narrow down the field type so we don't have to
+ * print type_key for cases like POD types.
+ * It also provides opportunities to enable short-cut getters for ObjectRef fields.
*/
int32_t field_static_type_index;
- /*!
- * \brief Mark whether field is readonly.
- */
- int32_t readonly;
- /*!
- * \brief Byte offset of the field.
- */
- int64_t byte_offset;
- /*! \brief The getter to access the field. */
- TVMFFIFieldGetter getter;
- /*! \brief The setter to access the field. */
- TVMFFIFieldSetter setter;
} TVMFFIFieldInfo;
/*!
@@ -347,11 +379,15 @@ typedef struct {
typedef struct {
/*! \brief The name of the field. */
TVMFFIByteArray name;
+ /*! \brief The docstring about the method. */
+ TVMFFIByteArray doc;
+ /*! \brief bitmask flags of the method. */
+ int64_t flags;
/*!
- * \brief The method wrapped as Function
- * \note The first argument to the method is always the self.
+ * \brief The method wrapped as ffi::Function, stored as AnyView.
+ * \note For instance methods, the first argument to the method is always self.
*/
- TVMFFIObjectHandle method;
+ TVMFFIAny method;
} TVMFFIMethodInfo;
/*!
@@ -379,6 +415,7 @@ typedef struct {
int32_t num_fields;
/*! \brief number of reflection acccesible methods. */
int32_t num_methods;
+
/*! \brief The reflection field information. */
TVMFFIFieldInfo* fields;
/*! \brief The reflection method. */
@@ -522,12 +559,29 @@ TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const
TVMFFIByteArray* name, void* symbol)
// Section: Type reflection support APIs
//------------------------------------------------------------
/*!
- * \brief Register type field information for rutnime reflection.
+ * \brief Register type field information for runtime reflection.
* \param type_index The type index
* \param info The field info to be registered.
* \return 0 when success, nonzero when failure happens
*/
-TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const
TVMFFIFieldInfo* info);
+TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const
TVMFFIFieldInfo* info);
+
+/*!
+ * \brief Register type method information for runtime reflection.
+ * \param type_index The type index
+ * \param info The method info to be registered.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const
TVMFFIMethodInfo* info);
+
+/*!
+ * \brief Get dynamic type info by type index.
+ *
+ * \param type_index The type index.
+ * \return The type information of the given type index,
+ *         returned as a pointer to TVMFFITypeInfo.
+ */
+TVM_FFI_DLL const TVMFFITypeInfo* TVMFFITypeGetMethod(int32_t type_index);
//------------------------------------------------------------
// Section: DLPack support APIs
@@ -638,7 +692,7 @@ TVM_FFI_DLL const TVMFFIByteArray* TVMFFITraceback(const
char* filename, int lin
*
* \return 0 if success, -1 if error occured
*/
-TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key,
+TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key,
int32_t static_type_index,
int32_t type_depth,
int32_t num_child_slots,
int32_t child_slots_can_overflow,
diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h
index 128c67830e..eb3f390bcd 100644
--- a/ffi/include/tvm/ffi/function.h
+++ b/ffi/include/tvm/ffi/function.h
@@ -836,14 +836,14 @@ class Function::Registry {
if constexpr (std::is_base_of_v<ObjectRef, T>) {
auto fwrap = [f](T target, Args... params) -> R {
// call method pointer
- return (target.*f)(params...);
+ return (target.*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
if constexpr (std::is_base_of_v<Object, T>) {
auto fwrap = [f](const T* target, Args... params) -> R {
// call method pointer
- return (const_cast<T*>(target)->*f)(params...);
+ return (const_cast<T*>(target)->*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
@@ -857,14 +857,14 @@ class Function::Registry {
if constexpr (std::is_base_of_v<ObjectRef, T>) {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
- return (target.*f)(params...);
+ return (target.*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
if constexpr (std::is_base_of_v<Object, T>) {
auto fwrap = [f](const T* target, Args... params) -> R {
// call method pointer
- return (target->*f)(params...);
+ return (target->*f)(std::forward<Args>(params)...);
};
return Register(ffi::Function::FromTyped(fwrap, name_));
}
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index e86689ebe2..72e6f0a1f8 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -103,6 +103,9 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t
object_type_index);
* This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO
* It is still OK to sub-class a terminal object type T and construct it
using make_object.
* But IsInstance check will only show that the object type is T(instead
of the sub-class).
+ * - _type_mutable:
+ * Whether to allow casting from Any/AnyView to a non-const pointer
+ * ObjectType*. Defaults to false, so the cast is not exposed.
*
* The following two fields are necessary for base classes that can be
sub-classed.
*
@@ -191,6 +194,7 @@ class Object {
// Default object type properties for sub-classes
static constexpr bool _type_final = false;
+ static constexpr bool _type_mutable = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
// NOTE: static type index field of the class
@@ -546,7 +550,7 @@ struct ObjectPtrEqual {
"Need to set _type_child_slots when parent specifies it.");
\
TVMFFIByteArray type_key{TypeName::_type_key,
\
std::char_traits<char>::length(TypeName::_type_key)}; \
- static int32_t tindex = TVMFFIGetOrAllocTypeIndex(
\
+ static int32_t tindex = TVMFFITypeGetOrAllocIndex(
\
&type_key, TypeName::_type_index, TypeName::_type_depth,
TypeName::_type_child_slots, \
TypeName::_type_child_slots_can_overflow,
ParentType::_GetOrAllocRuntimeTypeIndex()); \
return tindex;
\
@@ -576,7 +580,7 @@ struct ObjectPtrEqual {
"Need to set _type_child_slots when parent specifies it.");
\
TVMFFIByteArray type_key{TypeName::_type_key,
\
std::char_traits<char>::length(TypeName::_type_key)}; \
- static int32_t tindex = TVMFFIGetOrAllocTypeIndex(
\
+ static int32_t tindex = TVMFFITypeGetOrAllocIndex(
\
&type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots,
\
TypeName::_type_child_slots_can_overflow,
ParentType::_GetOrAllocRuntimeTypeIndex()); \
return tindex;
\
diff --git a/ffi/include/tvm/ffi/reflection/reflection.h
b/ffi/include/tvm/ffi/reflection/reflection.h
index 766b9b8099..a5ab6f4fe8 100644
--- a/ffi/include/tvm/ffi/reflection/reflection.h
+++ b/ffi/include/tvm/ffi/reflection/reflection.h
@@ -29,10 +29,31 @@
#include <tvm/ffi/type_traits.h>
#include <string>
+#include <utility>
namespace tvm {
namespace ffi {
-namespace details {
+/*! \brief Reflection namespace */
+namespace reflection {
+
+/*! \brief Trait that can be used to set field info */
+struct FieldInfoTrait {};
+
+/*!
+ * \brief Trait that can be used to set field default value
+ */
+class DefaultValue : public FieldInfoTrait {
+ public:
+ explicit DefaultValue(Any value) : value_(value) {}
+
+ void Apply(TVMFFIFieldInfo* info) const {
+ info->default_value = AnyView(value_).CopyToTVMFFIAny();
+ info->flags |= TVMFFIFieldFlagBitMaskHasDefault;
+ }
+
+ private:
+ Any value_;
+};
/*!
* \brief Get the byte offset of a class member field.
@@ -50,37 +71,108 @@ inline int64_t GetFieldByteOffsetToObject(T
Class::*field_ptr) {
return field_offset_to_class -
details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
}
-struct ReflectionDefFinish {};
-
class ReflectionDef {
public:
- explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {}
+ explicit ReflectionDef(int32_t type_index, const char* type_key)
+ : type_index_(type_index), type_key_(type_key) {}
+
+ /*!
+ * \brief Define a readonly field.
+ *
+ * \tparam Class The class type.
+ * \tparam T The field type.
+ * \tparam Extra The extra arguments.
+ *
+ * \param name The name of the field.
+ * \param field_ptr The pointer to the field.
+ * \param extra Extra arguments, each of which can be a docstring or a default value.
+ *
+ * \return The reflection definition.
+ */
+ template <typename Class, typename T, typename... Extra>
+ ReflectionDef& def_ro(const char* name, T Class::*field_ptr, Extra&&...
extra) {
+ RegisterField(name, field_ptr, false, std::forward<Extra>(extra)...);
+ return *this;
+ }
- template <typename Class, typename T>
- ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) {
- RegisterField(name, field_ptr, true);
+ /*!
+ * \brief Define a read-write field.
+ *
+ * \tparam Class The class type.
+ * \tparam T The field type.
+ * \tparam Extra The extra arguments.
+ *
+ * \param name The name of the field.
+ * \param field_ptr The pointer to the field.
+ * \param extra Extra arguments, each of which can be a docstring or a default value.
+ *
+ * \return The reflection definition.
+ */
+ template <typename Class, typename T, typename... Extra>
+ ReflectionDef& def_rw(const char* name, T Class::*field_ptr, Extra&&...
extra) {
+ RegisterField(name, field_ptr, true, std::forward<Extra>(extra)...);
return *this;
}
- template <typename Class, typename T>
- ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) {
- RegisterField(name, field_ptr, false);
+ /*!
+ * \brief Define a method.
+ *
+ * \tparam Func The function type.
+ * \tparam Extra The extra arguments.
+ *
+ * \param name The name of the method.
+ * \param func The function to be registered.
+ * \param extra Extra arguments; currently only a docstring is supported.
+ *
+ * \return The reflection definition.
+ */
+ template <typename Func, typename... Extra>
+ ReflectionDef& def(const char* name, Func&& func, Extra&&... extra) {
+ RegisterMethod(name, false, std::forward<Func>(func),
std::forward<Extra>(extra)...);
+ return *this;
+ }
+
+ /*!
+ * \brief Define a static method.
+ *
+ * \tparam Func The function type.
+ * \tparam Extra The extra arguments.
+ *
+ * \param name The name of the method.
+ * \param func The function to be registered.
+ * \param extra Extra arguments; currently only a docstring is supported.
+ *
+ * \return The reflection definition.
+ */
+ template <typename Func, typename... Extra>
+ ReflectionDef& def_static(const char* name, Func&& func, Extra&&... extra) {
+ RegisterMethod(name, true, std::forward<Func>(func),
std::forward<Extra>(extra)...);
return *this;
}
private:
- template <typename Class, typename T>
- void RegisterField(const char* name, T Class::*field_ptr, bool readonly) {
+ template <typename Class, typename T, typename... ExtraArgs>
+ void RegisterField(const char* name, T Class::*field_ptr, bool writable,
+ ExtraArgs&&... extra_args) {
TVMFFIFieldInfo info;
info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
info.field_static_type_index = TypeToFieldStaticTypeIndex<T>::value;
// store byte offset and setter, getter
// so the same setter can be reused for all the same type
info.byte_offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
- info.readonly = readonly;
+ info.flags = 0;
+ if (writable) {
+ info.flags |= TVMFFIFieldFlagBitMaskWritable;
+ }
info.getter = FieldGetter<T>;
info.setter = FieldSetter<T>;
- TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info));
+ // initialize default value to nullptr
+ info.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+ info.doc = TVMFFIByteArray{nullptr, 0};
+ // apply field info traits
+ ((ApplyFieldInfoTrait(&info, std::forward<ExtraArgs>(extra_args)), ...));
+ // call register
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info));
}
template <typename T>
@@ -97,14 +189,70 @@ class ReflectionDef {
TVM_FFI_SAFE_CALL_END();
}
+ template <typename T>
+ static void ApplyFieldInfoTrait(TVMFFIFieldInfo* info, const T& value) {
+ if constexpr (std::is_base_of_v<FieldInfoTrait, std::decay_t<T>>) {
+ value.Apply(info);
+ }
+ if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
+ info->doc = TVMFFIByteArray{value,
std::char_traits<char>::length(value)};
+ }
+ }
+
+ // register a method
+ template <typename Func, typename... Extra>
+ void RegisterMethod(const char* name, bool is_static, Func&& func,
Extra&&... extra) {
+ TVMFFIMethodInfo info;
+ info.name = TVMFFIByteArray{name, std::char_traits<char>::length(name)};
+ info.doc = TVMFFIByteArray{nullptr, 0};
+ info.flags = 0;
+ if (is_static) {
+ info.flags |= TVMFFIFieldFlagBitMaskIsStaticMethod;
+ }
+ // obtain the method function
+ Function method = GetMethod(std::string(type_key_) + "." + name,
std::forward<Func>(func));
+ info.method = AnyView(method).CopyToTVMFFIAny();
+ // apply method info traits
+ ((ApplyMethodInfoTrait(&info, std::forward<Extra>(extra)), ...));
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info));
+ }
+
+ template <typename T>
+ static void ApplyMethodInfoTrait(TVMFFIMethodInfo* info, const T& value) {
+ if constexpr (std::is_same_v<std::decay_t<T>, char*>) {
+ info->doc = TVMFFIByteArray{value,
std::char_traits<char>::length(value)};
+ }
+ }
+
+ template <typename Class, typename R, typename... Args>
+ static Function GetMethod(std::string name, R (Class::*func)(Args...)) {
+ auto fwrap = [func](const Class* target, Args... params) -> R {
+ return
(const_cast<Class*>(target)->*func)(std::forward<Args>(params)...);
+ };
+ return ffi::Function::FromTyped(fwrap, name);
+ }
+
+ template <typename Class, typename R, typename... Args>
+ static Function GetMethod(std::string name, R (Class::*func)(Args...) const)
{
+ auto fwrap = [func](const Class* target, Args... params) -> R {
+ return (target->*func)(std::forward<Args>(params)...);
+ };
+ return ffi::Function::FromTyped(fwrap, name);
+ }
+
+ template <typename Func>
+ static Function GetMethod(std::string name, Func&& func) {
+ return ffi::Function::FromTyped(std::forward<Func>(func), name);
+ }
+
int32_t type_index_;
+ const char* type_key_;
};
/*!
* \brief helper function to get reflection field info by type key and field
name
*/
-inline const TVMFFIFieldInfo* GetReflectionFieldInfo(std::string_view type_key,
- const char* field_name) {
+inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const
char* field_name) {
int32_t type_index;
TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
@@ -115,14 +263,18 @@ inline const TVMFFIFieldInfo*
GetReflectionFieldInfo(std::string_view type_key,
}
}
TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in "
<< type_key;
+ TVM_FFI_UNREACHABLE();
}
/*!
* \brief helper wrapper class to obtain a getter.
*/
-class ReflectionFieldGetter {
+class FieldGetter {
public:
- explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) :
field_info_(field_info) {}
+ explicit FieldGetter(const TVMFFIFieldInfo* field_info) :
field_info_(field_info) {}
+
+ explicit FieldGetter(std::string_view type_key, const char* field_name)
+ : FieldGetter(GetFieldInfo(type_key, field_name)) {}
Any operator()(const Object* obj_ptr) const {
Any result;
@@ -140,16 +292,81 @@ class ReflectionFieldGetter {
const TVMFFIFieldInfo* field_info_;
};
-#define TVM_FFI_REFLECTION_REG_VAR_DEF \
- static inline TVM_FFI_ATTRIBUTE_UNUSED ::tvm::ffi::details::ReflectionDef&
__TVMFFIReflectionReg
+/*!
+ * \brief helper wrapper class to obtain a setter.
+ */
+class FieldSetter {
+ public:
+ explicit FieldSetter(const TVMFFIFieldInfo* field_info) :
field_info_(field_info) {}
+
+ explicit FieldSetter(std::string_view type_key, const char* field_name)
+ : FieldSetter(GetFieldInfo(type_key, field_name)) {}
+
+ void operator()(const Object* obj_ptr, AnyView value) const {
+ const void* addr = reinterpret_cast<const char*>(obj_ptr) +
field_info_->byte_offset;
+ TVM_FFI_CHECK_SAFE_CALL(
+ field_info_->setter(const_cast<void*>(addr), reinterpret_cast<const
TVMFFIAny*>(&value)));
+ }
+
+ void operator()(const ObjectPtr<Object>& obj_ptr, AnyView value) const {
+ operator()(obj_ptr.get(), value);
+ }
+
+ void operator()(const ObjectRef& obj, AnyView value) const {
operator()(obj.get(), value); }
+
+ private:
+ const TVMFFIFieldInfo* field_info_;
+};
+
+/*!
+ * \brief Helper function to get reflection method info by type key and method name.
+ *
+ * \param type_key The type key.
+ * \param method_name The name of the method.
+ * \return The method info.
+ */
+inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const
char* method_name) {
+ int32_t type_index;
+ TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()};
+ TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index));
+ const TypeInfo* info = TVMFFIGetTypeInfo(type_index);
+ for (int32_t i = 0; i < info->num_methods; ++i) {
+ if (std::strncmp(info->methods[i].name.data, method_name,
info->methods[i].name.size) == 0) {
+ return &(info->methods[i]);
+ }
+ }
+ TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in
" << type_key;
+ TVM_FFI_UNREACHABLE();
+}
+
+/*!
+ * \brief Helper function to get the reflected method as a Function by type key and method name.
+ *
+ * \param type_key The type key.
+ * \param method_name The name of the method.
+ * \return The method function.
+ */
+inline Function GetMethod(std::string_view type_key, const char* method_name) {
+ const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name);
+ return AnyView::CopyFromTVMFFIAny(info->method).cast<Function>();
+}
+
+#define TVM_FFI_REFLECTION_REG_VAR_DEF
\
+ static inline TVM_FFI_ATTRIBUTE_UNUSED
::tvm::ffi::reflection::ReflectionDef& \
+ __TVMFFIReflectionReg
/*!
* helper macro to define a reflection definition for an object
*/
-#define TVM_FFI_REFLECTION_DEF(TypeName) \
- TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
-
::tvm::ffi::details::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex())
-} // namespace details
+#define TVM_FFI_REFLECTION_DEF(TypeName)
\
+ TVM_FFI_STR_CONCAT(TVM_FFI_REFLECTION_REG_VAR_DEF, __COUNTER__) =
\
+
::tvm::ffi::reflection::ReflectionDef(TypeName::_GetOrAllocRuntimeTypeIndex(), \
+ TypeName::_type_key)
+
+} // namespace reflection
+
+/*! \brief Shortcut to the reflection namespace */
+namespace refl = reflection;
} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_REFLECTION_REFLECTION_H_
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index c3eceff905..19df2e8e3d 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -401,7 +401,6 @@ TVM_FFI_INLINE std::string_view
ToStringView(TVMFFIByteArray str) {
template <int N>
struct TypeTraits<char[N]> : public TypeTraitsBase {
// NOTE: only enable implicit conversion into AnyView
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
static constexpr bool storage_enabled = false;
static TVM_FFI_INLINE void CopyToAnyView(const char src[N], TVMFFIAny*
result) {
@@ -417,7 +416,6 @@ struct TypeTraits<char[N]> : public TypeTraitsBase {
template <>
struct TypeTraits<const char*> : public TypeTraitsBase {
- static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIRawStr;
static constexpr bool storage_enabled = false;
static TVM_FFI_INLINE void CopyToAnyView(const char* src, TVMFFIAny* result)
{
diff --git a/ffi/include/tvm/ffi/type_traits.h
b/ffi/include/tvm/ffi/type_traits.h
index 02c9a90edc..5c291b5535 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -84,6 +84,7 @@ inline constexpr bool use_default_type_traits_v = true;
struct TypeTraitsBase {
static constexpr bool convert_enabled = true;
static constexpr bool storage_enabled = true;
+ static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny;
// get mismatched type when result mismatches the trait.
// this function is called after TryCastFromAnyView fails
// to get more detailed type information in runtime
@@ -588,17 +589,18 @@ struct ObjectRefWithFallbackTraitsBase : public
ObjectRefTypeTraitsBase<ObjectRe
// Traits for weak pointer of object
// NOTE: we require the weak pointer cast from
+
template <typename TObject>
-struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object,
TObject>>>
+struct TypeTraits<TObject*, std::enable_if_t<std::is_base_of_v<Object,
TObject>>>
: public TypeTraitsBase {
- static TVM_FFI_INLINE void CopyToAnyView(const TObject* src, TVMFFIAny*
result) {
+ static TVM_FFI_INLINE void CopyToAnyView(TObject* src, TVMFFIAny* result) {
TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
result->type_index = obj_ptr->type_index;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
result->v_obj = obj_ptr;
}
- static TVM_FFI_INLINE void MoveToAny(const TObject* src, TVMFFIAny* result) {
+ static TVM_FFI_INLINE void MoveToAny(TObject* src, TVMFFIAny* result) {
TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src);
result->type_index = obj_ptr->type_index;
TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
@@ -612,11 +614,17 @@ struct TypeTraits<const TObject*,
std::enable_if_t<std::is_base_of_v<Object, TOb
details::IsObjectInstance<TObject>(src->type_index);
}
- static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny*
src) {
+ if constexpr (!std::is_const_v<TObject>) {
+ static_assert(TObject::_type_mutable, "TObject must be mutable to enable
cast from Any");
+ }
return details::ObjectUnsafe::RawObjectPtrFromUnowned<TObject>(src->v_obj);
}
- static TVM_FFI_INLINE std::optional<const TObject*> TryCastFromAnyView(const
TVMFFIAny* src) {
+ static TVM_FFI_INLINE std::optional<TObject*> TryCastFromAnyView(const
TVMFFIAny* src) {
+ if constexpr (!std::is_const_v<TObject>) {
+ static_assert(TObject::_type_mutable, "TObject must be mutable to enable
cast from Any");
+ }
if (CheckAnyStrict(src)) return CopyFromAnyViewAfterCheck(src);
return std::nullopt;
}
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 63ec68790e..6ce149a7c9 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -54,6 +54,8 @@ class TypeTable {
std::vector<int32_t> type_acenstors_data;
/*! \brief type fields informaton */
std::vector<TVMFFIFieldInfo> type_fields_data;
+ /*! \brief type methods information */
+ std::vector<TVMFFIMethodInfo> type_methods_data;
// NOTE: the indices in [index, index + num_reserved_slots) are
// reserved for the child-class of this type.
/*! \brief Total number of slots reserved for the type and its children. */
@@ -186,12 +188,29 @@ class TypeTable {
Entry* entry = GetTypeEntry(type_index);
TVMFFIFieldInfo field_data = *info;
field_data.name = this->CopyString(info->name);
+ field_data.doc = this->CopyString(info->doc);
+ if (info->flags & TVMFFIFieldFlagBitMaskHasDefault) {
+ field_data.default_value =
+
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->default_value)).CopyToTVMFFIAny();
+ } else {
+ field_data.default_value = AnyView(nullptr).CopyToTVMFFIAny();
+ }
entry->type_fields_data.push_back(field_data);
// refresh ptr as the data can change
entry->fields = entry->type_fields_data.data();
entry->num_fields = static_cast<int32_t>(entry->type_fields_data.size());
}
+ void RegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info) {
+ Entry* entry = GetTypeEntry(type_index);
+ TVMFFIMethodInfo method_data = *info;
+ method_data.name = this->CopyString(info->name);
+ method_data.doc = this->CopyString(info->doc);
+ method_data.method =
this->CopyAny(AnyView::CopyFromTVMFFIAny(info->method)).CopyToTVMFFIAny();
+ entry->type_methods_data.push_back(method_data);
+ entry->methods = entry->type_methods_data.data();
+ entry->num_methods = static_cast<int32_t>(entry->type_methods_data.size());
+ }
void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// expected child slots compute the expected slots
@@ -262,16 +281,25 @@ class TypeTable {
}
TVMFFIByteArray CopyString(TVMFFIByteArray str) {
- std::unique_ptr<std::string> val = std::make_unique<std::string>(str.data,
str.size);
- TVMFFIByteArray c_val{val->data(), val->length()};
- string_pool_.emplace_back(std::move(val));
+ if (str.size == 0) {
+ return TVMFFIByteArray{nullptr, 0};
+ }
+ String val = String(str.data, str.size);
+ TVMFFIByteArray c_val{val.data(), val.length()};
+ any_pool_.emplace_back(std::move(val));
return c_val;
}
+ AnyView CopyAny(Any val) {
+ AnyView view = AnyView(val);
+ any_pool_.emplace_back(std::move(val));
+ return view;
+ }
+
int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin};
std::vector<std::unique_ptr<Entry>> type_table_;
std::unordered_map<std::string, int32_t> type_key2index_;
- std::vector<std::unique_ptr<std::string>> string_pool_;
+ std::vector<Any> any_pool_;
};
} // namespace ffi
} // namespace tvm
@@ -288,13 +316,19 @@ int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key,
int32_t* out_tindex) {
TVM_FFI_SAFE_CALL_END();
}
-int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) {
+int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo* info) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info);
TVM_FFI_SAFE_CALL_END();
}
-int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray* type_key, int32_t
static_type_index,
+int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo* info)
{
+ TVM_FFI_SAFE_CALL_BEGIN();
+ tvm::ffi::TypeTable::Global()->RegisterTypeMethod(type_index, info);
+ TVM_FFI_SAFE_CALL_END();
+}
+
+int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key, int32_t
static_type_index,
int32_t type_depth, int32_t num_child_slots,
int32_t child_slots_can_overflow, int32_t
parent_type_index) {
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
@@ -302,7 +336,7 @@ int32_t TVMFFIGetOrAllocTypeIndex(const TVMFFIByteArray*
type_key, int32_t stati
return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex(
s_type_key, static_type_index, type_depth, num_child_slots,
child_slots_can_overflow,
parent_type_index);
- TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetOrAllocTypeIndex);
+ TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFITypeGetOrAllocIndex);
}
const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) {
diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc
index 3ad81cd118..eea18c7c64 100644
--- a/ffi/tests/cpp/test_any.cc
+++ b/ffi/tests/cpp/test_any.cc
@@ -232,6 +232,9 @@ TEST(Any, Object) {
EXPECT_EQ(v1.use_count(), 3);
EXPECT_TRUE(any2.as<TInt>().has_value());
+ any2 = const_cast<TIntObj*>(v1_ptr);
+ EXPECT_TRUE(any2.as<TInt>().has_value());
+
// convert to raw opaque ptr
void* raw_v1_ptr = const_cast<TIntObj*>(v1_ptr);
any2 = raw_v1_ptr;
diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc
index cb42f32c6c..321af7ae16 100644
--- a/ffi/tests/cpp/test_array.cc
+++ b/ffi/tests/cpp/test_array.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/function.h>
#include "./testing_object.h"
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index fec167d257..76e8d35a99 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -20,33 +20,108 @@
#include <gtest/gtest.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/reflection/reflection.h>
+#include <tvm/ffi/string.h>
#include "./testing_object.h"
namespace {
-
using namespace tvm::ffi;
using namespace tvm::ffi::testing;
+TVM_FFI_REFLECTION_DEF(TFloatObj)
+ .def_rw("value", &TFloatObj::value, "float value field",
refl::DefaultValue(10.0))
+ .def("sub", [](const TFloatObj* self, double other) -> double { return
self->value - other; })
+ .def("add", &TFloatObj::Add, "add method");
+
+TVM_FFI_REFLECTION_DEF(TIntObj)
+ .def_ro("value", &TIntObj::value)
+ .def_static("static_add", &TInt::StaticAdd, "static add method");
+
+TVM_FFI_REFLECTION_DEF(TPrimExprObj)
+ .def_ro("dtype", &TPrimExprObj::dtype, "dtype field",
refl::DefaultValue("float"))
+ .def_ro("value", &TPrimExprObj::value, "value field",
refl::DefaultValue(0))
+ .def("sub", [](TPrimExprObj* self, double other) -> double {
+ // this is ok because TPrimExprObj is declared as mutable
+ return self->value - other;
+ });
+
struct A : public Object {
int64_t x;
int64_t y;
};
+TVM_FFI_REFLECTION_DEF(A).def_ro("x", &A::x).def_rw("y", &A::y);
+
TEST(Reflection, GetFieldByteOffset) {
- EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), sizeof(TVMFFIObject));
- EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 8 +
sizeof(TVMFFIObject));
- EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value),
sizeof(TVMFFIObject));
+ EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::x),
sizeof(TVMFFIObject));
+ EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&A::y), 8 +
sizeof(TVMFFIObject));
+ EXPECT_EQ(reflection::GetFieldByteOffsetToObject(&TIntObj::value),
sizeof(TVMFFIObject));
}
TEST(Reflection, FieldGetter) {
ObjectRef a = TInt(10);
- details::ReflectionFieldGetter
getter(details::GetReflectionFieldInfo("test.Int", "value"));
+ reflection::FieldGetter getter("test.Int", "value");
EXPECT_EQ(getter(a).cast<int>(), 10);
ObjectRef b = TFloat(10.0);
- details::ReflectionFieldGetter getter_float(
- details::GetReflectionFieldInfo("test.Float", "value"));
+ reflection::FieldGetter getter_float("test.Float", "value");
EXPECT_EQ(getter_float(b).cast<double>(), 10.0);
}
+
+TEST(Reflection, FieldSetter) {
+ ObjectRef a = TFloat(10.0);
+ reflection::FieldSetter setter("test.Float", "value");
+ setter(a, 20.0);
+ EXPECT_EQ(a.as<TFloatObj>()->value, 20.0);
+}
+
+TEST(Reflection, FieldInfo) {
+ const TVMFFIFieldInfo* info_int = reflection::GetFieldInfo("test.Int",
"value");
+ EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskHasDefault);
+ EXPECT_FALSE(info_int->flags & TVMFFIFieldFlagBitMaskWritable);
+ EXPECT_EQ(Bytes(info_int->doc).operator std::string(), "");
+
+ const TVMFFIFieldInfo* info_float = reflection::GetFieldInfo("test.Float",
"value");
+ EXPECT_EQ(info_float->default_value.v_float64, 10.0);
+ EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskHasDefault);
+ EXPECT_TRUE(info_float->flags & TVMFFIFieldFlagBitMaskWritable);
+ EXPECT_EQ(Bytes(info_float->doc).operator std::string(), "float value
field");
+
+ const TVMFFIFieldInfo* info_prim_expr_dtype =
reflection::GetFieldInfo("test.PrimExpr", "dtype");
+ AnyView default_value =
AnyView::CopyFromTVMFFIAny(info_prim_expr_dtype->default_value);
+ EXPECT_EQ(default_value.cast<String>(), "float");
+ EXPECT_EQ(default_value.as<String>().value().use_count(), 2);
+ EXPECT_TRUE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskHasDefault);
+ EXPECT_FALSE(info_prim_expr_dtype->flags & TVMFFIFieldFlagBitMaskWritable);
+ EXPECT_EQ(Bytes(info_prim_expr_dtype->doc).operator std::string(), "dtype
field");
+}
+
+TEST(Reflection, MethodInfo) {
+ const TVMFFIMethodInfo* info_int_static_add =
reflection::GetMethodInfo("test.Int", "static_add");
+ EXPECT_TRUE(info_int_static_add->flags &
TVMFFIFieldFlagBitMaskIsStaticMethod);
+ EXPECT_EQ(Bytes(info_int_static_add->doc).operator std::string(), "static
add method");
+
+ const TVMFFIMethodInfo* info_float_add =
reflection::GetMethodInfo("test.Float", "add");
+ EXPECT_FALSE(info_float_add->flags & TVMFFIFieldFlagBitMaskIsStaticMethod);
+ EXPECT_EQ(Bytes(info_float_add->doc).operator std::string(), "add method");
+
+ const TVMFFIMethodInfo* info_float_sub =
reflection::GetMethodInfo("test.Float", "sub");
+ EXPECT_FALSE(info_float_sub->flags & TVMFFIFieldFlagBitMaskIsStaticMethod);
+ EXPECT_EQ(Bytes(info_float_sub->doc).operator std::string(), "");
+}
+
+TEST(Reflection, CallMethod) {
+ Function static_int_add = reflection::GetMethod("test.Int", "static_add");
+ EXPECT_EQ(static_int_add(TInt(1), TInt(2)).cast<TInt>()->value, 3);
+
+ Function float_add = reflection::GetMethod("test.Float", "add");
+ EXPECT_EQ(float_add(TFloat(1), 2.0).cast<double>(), 3.0);
+
+ Function float_sub = reflection::GetMethod("test.Float", "sub");
+ EXPECT_EQ(float_sub(TFloat(1), 2.0).cast<double>(), -1.0);
+
+ Function prim_expr_sub = reflection::GetMethod("test.PrimExpr", "sub");
+ EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
+}
+
} // namespace
diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc
index e0f69d8200..5735e86eca 100644
--- a/ffi/tests/cpp/test_tuple.cc
+++ b/ffi/tests/cpp/test_tuple.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/tuple.h>
+#include <tvm/ffi/function.h>
#include "./testing_object.h"
diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc
index 451913c992..b140e7db6e 100644
--- a/ffi/tests/cpp/test_variant.cc
+++ b/ffi/tests/cpp/test_variant.cc
@@ -20,6 +20,7 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
#include <tvm/ffi/memory.h>
#include "./testing_object.h"
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index 69a91efc46..8a91848845 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -22,7 +22,6 @@
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
-#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>
namespace tvm {
@@ -65,12 +64,12 @@ class TIntObj : public TNumberObj {
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj);
};
-TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value);
-
class TInt : public TNumber {
public:
explicit TInt(int64_t value) { data_ = make_object<TIntObj>(value); }
+ static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value +
rhs->value); }
+
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj);
};
@@ -80,12 +79,12 @@ class TFloatObj : public TNumberObj {
TFloatObj(double value) : value(value) {}
+ double Add(double other) const { return value + other; }
+
static constexpr const char* _type_key = "test.Float";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TFloatObj, TNumberObj);
};
-TVM_FFI_REFLECTION_DEF(TFloatObj).def_readonly("value", &TFloatObj::value);
-
class TFloat : public TNumber {
public:
explicit TFloat(double value) { data_ = make_object<TFloatObj>(value); }
@@ -102,6 +101,7 @@ class TPrimExprObj : public Object {
TPrimExprObj(std::string dtype, double value) : dtype(dtype), value(value) {}
static constexpr const char* _type_key = "test.PrimExpr";
+ static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TPrimExprObj, Object);
};