This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 89ccf3f9677d0e8443f379c26fe220a4fbd91d78 Author: tqchen <[email protected]> AuthorDate: Wed Aug 14 14:53:14 2024 -0400 [FFI] Object type hierachy cast and check support Co-authored-by: Junru Shao <[email protected]> --- ffi/include/tvm/ffi/c_api.h | 26 +++- ffi/include/tvm/ffi/error.h | 2 +- ffi/include/tvm/ffi/object.h | 201 +++++++++++++++++++++-------- ffi/src/ffi/object.cc | 265 +++++++++++++++++++++------------------ ffi/tests/example/test_error.cc | 1 - ffi/tests/example/test_object.cc | 93 +++++++++++++- 6 files changed, 410 insertions(+), 178 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 2616ae1648..e0597d0dda 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -85,8 +85,7 @@ typedef enum { // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) // kTVMFFIDynObject is used to indicate that the type index // is dynamic and needs to be looked up at runtime - kTVMFFIDynObject = 128, - kTVMFFIDynObjectBegin = 129 + kTVMFFIDynObjectBegin = 128 #ifdef __cplusplus }; #else @@ -142,6 +141,29 @@ typedef struct { const char* bytes; } TVMFFIByteArray; +/*! + * \brief Runtime type information for object type checking. + */ +typedef struct { + /*! + *\brief The runtime type index, + * It can be allocated during runtime if the type is dynamic. + */ + int32_t type_index; + /*! \brief number of parent types in the type hierachy. */ + int32_t type_depth; + /*! \brief the unique type key to identify the type. */ + const char* type_key; + /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ + uint64_t type_key_hash; + /*! 
+ * \brief type_acenstors[depth] stores the type_index of the ancestor at depth level + * \note To keep things simple, we do not allow multiple inheritance so the + * hierarchy stays as a tree + */ + const int32_t* type_acenstors; +} TVMFFITypeInfo; + #ifdef __cplusplus } // TVM_FFI_EXTERN_C #endif diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index 250206ac8b..f894166ef3 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -95,7 +95,7 @@ class Error : return get()->what_str.c_str(); } - TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj) + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj); }; namespace details { diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 03ba7d7dbf..5a6b552e1c 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -33,10 +33,44 @@ namespace tvm { namespace ffi { using TypeIndex = TVMFFITypeIndex; +using TypeInfo = TVMFFITypeInfo; namespace details { // forward declare object internal struct ObjectInternal; + +// Code section that depends on dynamic components +#if TVM_FFI_ALLOW_DYN_TYPE +/*! + * \brief Initialize the type info during runtime. + * + * When the function is called for the first time for a type, + * it will register the type to the type table in the runtime. + * + * If static_type_index is negative, the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. + * + * \param type_key The type key. + * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index + * \param num_child_slots Number of slots reserved for its children. + * \param child_slots_can_overflow Whether to allow child to overflow the slots. + * \param parent_type_index Parent type index, pass in -1 if it is root. 
+ * + * \return The allocated type index + */ +TVM_FFI_DLL int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_type_index, + int32_t type_depth, int32_t num_child_slots, + bool child_slots_can_overflow, + int32_t parent_type_index); + +/*! + * \brief Get Type information from type index. + * \param type_index The type index + * \return The type information + */ +TVM_FFI_DLL const TypeInfo* ObjectGetTypeInfo(int32_t type_index); +#endif // TVM_FFI_ALLOW_DYN_TYPE } // namespace details /*! @@ -89,6 +123,47 @@ class Object { header_.deleter = nullptr; } + /*! + * Check if the object is an instance of TargetType. + * \tparam TargetType The target type to be checked. + * \return Whether the object is an instance of TargetType. + */ + template <typename TargetType> + bool IsInstance() const { + // Everything is a subclass of object. + if constexpr (std::is_same<TargetType, Object>::value) return true; + + if constexpr (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return header_.type_index == TargetType::RuntimeTypeIndex(); + } + + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + int32_t target_type_index = TargetType::RuntimeTypeIndex(); + int32_t begin = target_type_index; + // The condition will be optimized by constant-folding. + if constexpr (TargetType::_type_child_slots != 0) { + int32_t end = begin + TargetType::_type_child_slots; + if (header_.type_index >= begin && header_.type_index < end) return true; + } else { + if (header_.type_index == begin) return true; + } + if (!TargetType::_type_child_slots_can_overflow) return false; + // Invariant: parent index is always smaller than the child. 
+ if (header_.type_index < target_type_index) return false; + // Do a runtime lookup of type information +#if TVM_FFI_ALLOW_DYN_TYPE + // the function checks that the info exists + const TypeInfo* type_info = details::ObjectGetTypeInfo(header_.type_index); + return (type_info->type_depth > TargetType::_type_depth && + type_info->type_acenstors[TargetType::_type_depth] == target_type_index); +#else + return false; +#endif + } + // Information about the object static constexpr const char* _type_key = "runtime.Object"; @@ -96,11 +171,10 @@ class Object { static constexpr bool _type_final = false; static constexpr uint32_t _type_child_slots = 0; static constexpr bool _type_child_slots_can_overflow = true; - // NOTE: the following field is not type index of Object - // but was intended to be used by sub-classes as default value. - // The type index of Object is TypeIndex::kRoot + // NOTE: static type index field of the class static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; - + // the static type depth of the class + static constexpr int32_t _type_depth = 0; // The following functions are provided by macro // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_DECLARE_FINAL_OBJECT_INFO /*! @@ -110,7 +184,6 @@ class Object { static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } /*! * \brief Internal function to get or allocate a runtime index. - * \note */ static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } @@ -342,7 +415,13 @@ class ObjectRef { * \tparam ObjectType the target type, must be a subtype of Object */ template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>> - inline const ObjectType* as() const; + const ObjectType* as() const { + if (data_ != nullptr && data_->IsInstance<ObjectType>()) { + return static_cast<ObjectType*>(data_.get()); + } else { + return nullptr; + } + } /*! \brief type indicate the container type. 
*/ using ContainerType = Object; @@ -377,14 +456,30 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ -#define TVM_FFI_OBJECT_STATIC_CHECKS(TypeName, ParentType) \ +#define TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ static_assert(!ParentType::_type_final, "ParentType marked as final"); \ static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ TypeName::_type_child_slots < ParentType::_type_child_slots, \ "Need to set _type_child_slots when parent specifies it."); \ static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); + "Need to set _type_child_slots when parent specifies it.") + +// If dynamic type is enabled, we still need to register the runtime type of parent +#if TVM_FFI_ALLOW_DYN_TYPE +#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static int32_t tindex = ::tvm::ffi::details::ObjectGetOrAllocTypeIndex( \ + TypeName::_type_key, TypeName::_type_index, TypeName::_type_depth, \ + TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow, \ + ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static inline int32_t _register_type_index = _GetOrAllocRuntimeTypeIndex() +#else +#define TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType) +#endif /*! * \brief Helper macro to declare a object that comes with static type index. 
@@ -392,25 +487,38 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); * \param ParentType The name of the ParentType */ #define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ - TVM_FFI_OBJECT_STATIC_CHECKS(TypeName, ParentType) \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } + TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType); \ + static int32_t RuntimeTypeIndex() { return TypeName::_type_index; }\ + TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) /*! * \brief helper macro to declare a base object type that can be inherited. * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ -#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - static_assert(TVM_FFI_ALLOW_DYN_TYPE, \ - "Dynamic object depend on TVM_FFI_ALLOW_DYN_TYPE cd set to 1"); \ - TVM_FFI_OBJECT_STATIC_CHECKS(TypaName, ParentType) \ - static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex(); \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - return ::tvm::ffi::details::ObjectGetOrAllocTypeIndex( \ - TypeName::_type_key, -1, ParentType::_GetOrAllocRuntimeTypeIndex(), \ - TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \ - } +#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + static_assert(TVM_FFI_ALLOW_DYN_TYPE, \ + "Dynamic object depend on TVM_FFI_ALLOW_DYN_TYPE cd set to 1"); \ + TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType); \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static int32_t tindex = ::tvm::ffi::details::ObjectGetOrAllocTypeIndex( \ + TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ + static inline int32_t _type_index = 
_GetOrAllocRuntimeTypeIndex() + +/*! + * \brief helper macro to declare type information in a final class. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr int _type_child_slots = 0; \ + static const constexpr bool _type_final = true; \ + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) + /* * \brief Define object reference methods. @@ -424,7 +532,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ - using ContainerType = ObjectName; + using ContainerType = ObjectName /* * \brief Define object reference methods that is not nullable. @@ -439,9 +547,25 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName; + using ContainerType = ObjectName namespace details { + +// auxiliary class to enable static type info table at depth +template <int depth> +struct TypeInfoAtDepth : public TypeInfo { + /*! \brief extra type acenstors fields */ + int32_t _type_acenstors[depth]; + + TypeInfoAtDepth(const char* type_key, int32_t static_type_index) { + this->type_key = type_key; + this->type_key_hash = 0; + this->type_index = static_type_index; + this->type_depth = depth; + this->type_acenstors = _type_acenstors; + } +}; + /*! * \brief Namespace to internally manipulate object class. 
* \note These functions are only supposed to be used by internal @@ -469,37 +593,6 @@ struct ObjectInternal { } }; -// Code section that depends on dynamic components -#if TVM_FFI_ALLOW_DYN_TYPE -/*! - * \brief Get the type index using type key. - * - * When the function is first time called for a type, - * it will register the type to the type table in the runtime. - * If the static_tindex is TypeIndex::kDynamic, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * - * \param type_key the type key. - * \param static_tindex Static type index if any, can be -1, which means this is a dynamic index - * \param parent_tindex The index of the parent. - * \param type_child_slots Number of slots reserved for its children. - * \param type_child_slots_can_overflow Whether to allow child to overflow the slots. - * - * \return The allocated type index - */ -TVM_FFI_DLL int ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_tindex, - int32_t parent_tindex, int32_t type_child_slots, - bool type_child_slots_can_overflow); - -/*! - * \brief Check whether child type is derived from parent type. - * \param child_type_index The candidate child type index. - * \param parent_type_index The candidate parent type index. - * \return the Check result. - */ -TVM_FFI_DLL bool ObjectDerivedFrom(int32_t child_type_index, int32_t parent_type_index); -#endif // TVM_FFI_ALLOW_DYN_TYPE } // namespace details } // namespace ffi } // namespace tvm diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index ac8e04ebb9..f4e2ff3e30 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -23,6 +23,7 @@ #include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> +#include <memory> #include <string> #include <unordered_map> #include <utility> @@ -31,26 +32,6 @@ namespace tvm { namespace ffi { -/*! \brief Type information */ -struct TypeInfo { - /*! \brief The current index. */ - int32_t index{0}; - /*! 
\brief Index of the parent in the type hierarchy */ - int32_t parent_index{0}; - // NOTE: the indices in [index, index + num_reserved_slots) are - // reserved for the child-class of this type. - /*! \brief Total number of slots reserved for the type and its children. */ - int32_t num_slots{0}; - /*! \brief number of allocated child slots. */ - int32_t allocated_slots{0}; - /*! \brief Whether child can overflow. */ - bool child_slots_can_overflow{true}; - /*! \brief name of the type. */ - std::string name; - /*! \brief hash of the name */ - size_t name_hash{0}; -}; - /*! * \brief Type context that manages the type hierarchy information. * @@ -61,150 +42,196 @@ struct TypeInfo { * * Then the followup code will leverage the information */ -class TypeContext { +class TypeTable { public: - // NOTE: this is a relatively slow path for child checking - // Most types are already checked by the fast-path via reserved slot checking. - bool DerivedFrom(int32_t child_tindex, int32_t parent_tindex) { - // invariance: child's type index is always bigger than its parent. - if (child_tindex < parent_tindex) return false; - if (child_tindex == parent_tindex) return true; - TVM_FFI_ICHECK_LT(child_tindex, type_table_.size()); - while (child_tindex > parent_tindex) { - child_tindex = type_table_[child_tindex].parent_index; + /*! \brief Type information */ + struct Entry : public TypeInfo { + /*! \brief stored type key */ + std::string type_key_data; + /*! \brief acenstor information */ + std::vector<int32_t> type_acenstors_data; + // NOTE: the indices in [index, index + num_reserved_slots) are + // reserved for the child-class of this type. + /*! \brief Total number of slots reserved for the type and its children. */ + int32_t num_slots; + /*! \brief number of allocated child slots. */ + int32_t allocated_slots; + /*! \brief Whether child can overflow. 
*/ + bool child_slots_can_overflow{true}; + + Entry(int32_t type_index, int32_t type_depth, std::string type_key, int32_t num_slots, + bool child_slots_can_overflow, const Entry* parent) { + // setup fields in the class + this->type_key_data = std::move(type_key); + this->num_slots = num_slots; + this->allocated_slots = 1; + this->child_slots_can_overflow = child_slots_can_overflow; + // set up type acenstors information + if (type_depth != 0) { + TVM_FFI_ICHECK_NOTNULL(parent); + TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1); + type_acenstors_data.resize(type_depth); + // copy over parent's type information + for (int32_t i = 0; i < parent->type_depth; ++i) { + type_acenstors_data[i] = parent->type_acenstors[i]; + } + // set last type information to be parent + type_acenstors_data[parent->type_depth] = parent->type_index; + } + // initialize type info: no change to type_key and type_acenstors fields + // after this line + this->type_index = type_index; + this->type_depth = type_depth; + this->type_key = this->type_key_data.c_str(); + this->type_key_hash = std::hash<std::string>()(this->type_key_data); + this->type_acenstors = type_acenstors_data.data(); } - return child_tindex == parent_tindex; - } + }; - int32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, int32_t static_tindex, - int32_t parent_tindex, int32_t num_child_slots, - bool child_slots_can_overflow) { - auto it = type_key2index_.find(skey); + int32_t GetOrAllocTypeIndex(std::string type_key, int32_t static_type_index, int32_t type_depth, + int32_t num_child_slots, bool child_slots_can_overflow, + int32_t parent_type_index) { + auto it = type_key2index_.find(type_key); if (it != type_key2index_.end()) { - return it->second; + return type_table_[it->second]->type_index; } - // try to allocate from parent's type table. 
- TVM_FFI_ICHECK_LT(parent_tindex, type_table_.size()) - << " skey=" << skey << ", static_index=" << static_tindex; - TypeInfo& pinfo = type_table_[parent_tindex]; - TVM_FFI_ICHECK_EQ(pinfo.index, parent_tindex); + // get parent's entry + Entry* parent = [&]() -> Entry* { + if (parent_type_index < 0) return nullptr; + // try to allocate from parent's type table. + TVM_FFI_ICHECK_LT(parent_type_index, type_table_.size()) + << " type_key=" << type_key << ", static_index=" << static_type_index; + return type_table_[parent_type_index].get(); + }(); + + // get allocated index + int32_t allocated_tindex = [&]() { + // Step 0: static allocation + if (static_type_index >= 0) { + TVM_FFI_ICHECK_LT(static_type_index, type_table_.size()); + TVM_FFI_ICHECK(type_table_[static_type_index] == nullptr) + << "Conflicting static index " << static_type_index << " between " + << type_table_[static_type_index]->type_key << " and " << type_key; + return static_type_index; + } + TVM_FFI_ICHECK_NOTNULL(parent); + int num_slots = num_child_slots + 1; + if (parent->allocated_slots + num_slots <= parent->num_slots) { + // allocate the slot from parent's reserved pool + int32_t allocated_tindex = parent->type_index + parent->allocated_slots; + // update parent's state + parent->allocated_slots += num_slots; + return allocated_tindex; + } + // Step 2: allocate from overflow + TVM_FFI_ICHECK(parent->child_slots_can_overflow) + << "Reach maximum number of sub-classes for " << parent->type_key; + // allocate new entries. + int32_t allocated_tindex = type_counter_; + type_counter_ += num_slots; + TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_); + type_table_.reserve(type_counter_); + // resize type table + while (static_cast<int32_t>(type_table_.size()) < type_counter_) { + type_table_.emplace_back(nullptr); + } + return allocated_tindex; + }(); // if parent cannot overflow, then this class cannot. 
- if (!pinfo.child_slots_can_overflow) { + if (parent != nullptr && !(parent->child_slots_can_overflow)) { child_slots_can_overflow = false; } - // total number of slots include the type itself. - int32_t num_slots = num_child_slots + 1; - int32_t allocated_tindex; - - if (static_tindex > 0) { - // statically assigned type - allocated_tindex = static_tindex; - TVM_FFI_ICHECK_LT(static_tindex, type_table_.size()); - TVM_FFI_ICHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U) - << "Conflicting static index " << static_tindex << " between " - << type_table_[allocated_tindex].name << " and " << skey; - } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) { - // allocate the slot from parent's reserved pool - allocated_tindex = parent_tindex + pinfo.allocated_slots; - // update parent's state - pinfo.allocated_slots += num_slots; - } else { - TVM_FFI_ICHECK(pinfo.child_slots_can_overflow) - << "Reach maximum number of sub-classes for " << pinfo.name; - // allocate new entries. - allocated_tindex = type_counter_; - type_counter_ += num_slots; - TVM_FFI_ICHECK_LE(type_table_.size(), type_counter_); - type_table_.resize(type_counter_, TypeInfo()); + + if (parent != nullptr) { + TVM_FFI_ICHECK_GT(allocated_tindex, parent->type_index); } - TVM_FFI_ICHECK_GT(allocated_tindex, parent_tindex); - // initialize the slot. - type_table_[allocated_tindex].index = allocated_tindex; - type_table_[allocated_tindex].parent_index = parent_tindex; - type_table_[allocated_tindex].num_slots = num_slots; - type_table_[allocated_tindex].allocated_slots = 1; - type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow; - type_table_[allocated_tindex].name = skey; - type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey); + + type_table_[allocated_tindex] = + std::make_unique<Entry>(allocated_tindex, type_depth, type_key, num_child_slots + 1, + child_slots_can_overflow, parent); // update the key2index mapping. 
- type_key2index_[skey] = allocated_tindex; + type_key2index_[type_key] = allocated_tindex; return allocated_tindex; } - const std::string& TypeIndex2Key(int32_t tindex) { - if (tindex != 0) { - // always return the right type key for root - // for non-root type nodes, allocated slots should not equal 0 - TVM_FFI_ICHECK(tindex < static_cast<int32_t>(type_table_.size()) && - type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; - } - return type_table_[tindex].name; - } - - size_t TypeIndex2KeyHash(int32_t tindex) { - TVM_FFI_ICHECK(tindex < static_cast<int32_t>(type_table_.size()) && - type_table_[tindex].allocated_slots != 0) - << "Unknown type index " << tindex; - return type_table_[tindex].name_hash; + int32_t TypeKey2Index(const std::string& type_key) { + auto it = type_key2index_.find(type_key); + TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type " << type_key; + return it->second; } - int32_t TypeKey2Index(const std::string& skey) { - auto it = type_key2index_.find(skey); - TVM_FFI_ICHECK(it != type_key2index_.end()) << "Cannot find type " << skey; - return it->second; + const TypeInfo* GetTypeInfo(int32_t type_index) { + const TypeInfo* info = nullptr; + if (type_index >= 0 && static_cast<size_t>(type_index) < type_table_.size()) { + info = type_table_[type_index].get(); + } + TVM_FFI_ICHECK(info != nullptr) << "Cannot find type info for type_index=" << type_index; + return info; } void Dump(int min_children_count) { std::vector<int> num_children(type_table_.size(), 0); // reverse accumulation so we can get total counts in a bottom-up manner. 
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { - if (it->index != 0) { - num_children[it->parent_index] += num_children[it->index] + 1; + const Entry* ptr = it->get(); + if (ptr != nullptr && ptr->type_depth != 0) { + int parent_index = ptr->type_acenstors[ptr->type_depth - 1]; + num_children[parent_index] += num_children[ptr->type_index] + 1; } } - for (const auto& info : type_table_) { - if (info.index != 0 && num_children[info.index] >= min_children_count) { - std::cerr << '[' << info.index << "] " << info.name - << "\tparent=" << type_table_[info.parent_index].name - << "\tnum_child_slots=" << info.num_slots - 1 - << "\tnum_children=" << num_children[info.index] << std::endl; + for (const auto& ptr : type_table_) { + if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) { + std::cerr << '[' << ptr->type_index << "]\t" << ptr->type_key; + if (ptr->type_depth != 0) { + int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]; + std::cerr << "\tparent=" << type_table_[parent_index]->type_key; + } else { + std::cerr << "\tparent=root"; + } + std::cerr << "\tnum_child_slots=" << ptr->num_slots - 1 + << "\tnum_children=" << num_children[ptr->type_index] << std::endl; } } } - static TypeContext* Global() { - static TypeContext inst; + static TypeTable* Global() { + static TypeTable inst; return &inst; } private: - TypeContext() { - type_table_.resize(TypeIndex::kTVMFFIDynObjectBegin, TypeInfo()); - type_table_[0].name = "runtime.Object"; + TypeTable() { + type_table_.reserve(TypeIndex::kTVMFFIDynObjectBegin); + for (int32_t i = 0; i < TypeIndex::kTVMFFIDynObjectBegin; ++i) { + type_table_.emplace_back(nullptr); + } + // initialize the entry for object + this->GetOrAllocTypeIndex(Object::_type_key, Object::_type_index, Object::_type_depth, + Object::_type_child_slots, Object::_type_child_slots_can_overflow, + -1); } int32_t type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; - std::vector<TypeInfo> type_table_; + 
std::vector<std::unique_ptr<Entry>> type_table_; std::unordered_map<std::string, int32_t> type_key2index_; }; namespace details { -int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_tindex, - int32_t parent_tindex, int32_t type_child_slots, - bool type_child_slots_can_overflow) { - return tvm::ffi::TypeContext::Global()->GetOrAllocRuntimeTypeIndex( - type_key, static_tindex, parent_tindex, type_child_slots, type_child_slots_can_overflow != 0); +int32_t ObjectGetOrAllocTypeIndex(const char* type_key, int32_t static_type_index, + int32_t type_depth, int32_t num_child_slots, + bool child_slots_can_overflow, int32_t parent_index) { + return tvm::ffi::TypeTable::Global()->GetOrAllocTypeIndex(type_key, static_type_index, type_depth, + num_child_slots, + child_slots_can_overflow, parent_index); } -bool ObjectDerivedFrom(int32_t child_type_index, int32_t parent_type_index) { - return static_cast<int>( - tvm::ffi::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index)); +const TypeInfo* ObjectGetTypeInfo(int32_t type_index) { + return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index); } } // namespace details } // namespace ffi diff --git a/ffi/tests/example/test_error.cc b/ffi/tests/example/test_error.cc index 4f208b6999..015551c953 100644 --- a/ffi/tests/example/test_error.cc +++ b/ffi/tests/example/test_error.cc @@ -33,7 +33,6 @@ TEST(CheckError, Traceback) { TVM_FFI_ICHECK_GT(2, 3); } catch (const Error& error) { EXPECT_EQ(error->kind, "InternalError"); - std::cout << error.what(); std::string what = error.what(); EXPECT_NE(what.find("line"), std::string::npos); EXPECT_NE(what.find("2 > 3"), std::string::npos); diff --git a/ffi/tests/example/test_object.cc b/ffi/tests/example/test_object.cc index 959246fda0..4395fccf52 100644 --- a/ffi/tests/example/test_object.cc +++ b/ffi/tests/example/test_object.cc @@ -6,11 +6,55 @@ namespace { using namespace tvm::ffi; -class IntObj : public Object { +class NumberObj : public Object { + 
public: + // declare as one slot, with float as overflow + static constexpr uint32_t _type_child_slots = 1; + static constexpr const char* _type_key = "test.Number"; + TVM_FFI_DECLARE_BASE_OBJECT_INFO(NumberObj, Object); +}; + +class Number : public ObjectRef { + public: + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Number, ObjectRef, NumberObj); +}; + +class IntObj : public NumberObj { public: int64_t value; IntObj(int64_t value) : value(value) {} + + static constexpr const char* _type_key = "test.Int"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(IntObj, NumberObj); +}; + +class Int : public Number { + public: + explicit Int(int64_t value) { + data_ = make_object<IntObj>(value); + } + + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Int, Number, IntObj); +}; + +class FloatObj : public NumberObj { + public: + double value; + + FloatObj(double value) : value(value) {} + + static constexpr const char* _type_key = "test.Float"; + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(FloatObj, NumberObj); +}; + +class Float : public Number { + public: + explicit Float(double value) { + data_ = make_object<FloatObj>(value); + } + + TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(Float, Number, FloatObj); }; TEST(Object, RefCounter) { @@ -32,4 +76,51 @@ TEST(Object, RefCounter) { EXPECT_EQ(c->value, 11); } +TEST(Object, TypeInfo) { + const TypeInfo* info = tvm::ffi::details::ObjectGetTypeInfo(IntObj::RuntimeTypeIndex()); + EXPECT_TRUE(info != nullptr); + EXPECT_EQ(info->type_index, IntObj::RuntimeTypeIndex()); + EXPECT_EQ(info->type_depth, 2); + EXPECT_EQ(info->type_acenstors[0], Object::_type_index); + EXPECT_EQ(info->type_acenstors[1], NumberObj::_type_index); + EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin); +} + +TEST(Object, InstanceCheck) { + ObjectPtr<Object> a = make_object<IntObj>(11); + ObjectPtr<Object> b = make_object<FloatObj>(11); + + EXPECT_TRUE(a->IsInstance<Object>()); + EXPECT_TRUE(a->IsInstance<NumberObj>()); + EXPECT_TRUE(a->IsInstance<IntObj>()); + 
EXPECT_TRUE(!a->IsInstance<FloatObj>()); + + EXPECT_TRUE(a->IsInstance<Object>()); + EXPECT_TRUE(b->IsInstance<NumberObj>()); + EXPECT_TRUE(!b->IsInstance<IntObj>()); + EXPECT_TRUE(b->IsInstance<FloatObj>()); +} + +TEST(ObjectRef, as) { + ObjectRef a = Int(10); + ObjectRef b = Float(20); + // nullable object + ObjectRef c(nullptr); + + EXPECT_TRUE(a.as<IntObj>() != nullptr); + EXPECT_TRUE(a.as<FloatObj>() == nullptr); + EXPECT_TRUE(a.as<NumberObj>() != nullptr); + + EXPECT_TRUE(b.as<IntObj>() == nullptr); + EXPECT_TRUE(b.as<FloatObj>() != nullptr); + EXPECT_TRUE(b.as<NumberObj>() != nullptr); + + EXPECT_TRUE(c.as<IntObj>() == nullptr); + EXPECT_TRUE(c.as<FloatObj>() == nullptr); + EXPECT_TRUE(c.as<NumberObj>() == nullptr); + + EXPECT_EQ(a.as<IntObj>()->value, 10); + EXPECT_EQ(b.as<FloatObj>()->value, 20); +} + } // namespace
