This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 61249b46e790efbd0389ad1af0cc5148ac4d5e2a Author: tqchen <[email protected]> AuthorDate: Fri Dec 27 08:02:30 2024 +0800 [FFI] Initial reflection support --- ffi/cmake/Utils/CxxWarning.cmake | 2 +- ffi/include/tvm/ffi/c_api.h | 209 +++++++++++++++++++++++++-------- ffi/include/tvm/ffi/function_details.h | 6 +- ffi/include/tvm/ffi/object.h | 31 +---- ffi/include/tvm/ffi/reflection.h | 182 ++++++++++++++++++++++++++++ ffi/include/tvm/ffi/type_traits.h | 9 ++ ffi/src/ffi/object.cc | 49 +++++++- ffi/tests/cpp/test_reflection.cc | 49 ++++++++ ffi/tests/cpp/testing_object.h | 5 +- 9 files changed, 455 insertions(+), 87 deletions(-) diff --git a/ffi/cmake/Utils/CxxWarning.cmake b/ffi/cmake/Utils/CxxWarning.cmake index 50ee5b616d..efa9781e55 100644 --- a/ffi/cmake/Utils/CxxWarning.cmake +++ b/ffi/cmake/Utils/CxxWarning.cmake @@ -1,7 +1,7 @@ function(add_cxx_warning target_name) # GNU, Clang, or AppleClang if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang|AppleClang") - target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic") + target_compile_options(${target_name} PRIVATE "-Werror" "-Wall" "-Wextra" "-Wpedantic" "-Wno-unused-parameter") return() endif() # MSVC diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index c580410581..18c44e40da 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -51,27 +51,40 @@ enum TVMFFITypeIndex : int32_t { #else typedef enum { #endif - // [Section] On-stack POD Types: [0, kTVMFFIStaticObjectBegin) + // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, // which is not owned by TVMFFIAny. It is required that the following // invariant holds: // - `Any::type_index` is never `kTVMFFIRawStr` // - `AnyView::type_index` can be `kTVMFFIRawStr` + // + // NOTE: kTVMFFIAny is a root type of everything + // we include it so TypeIndex captures all possible runtime values. 
+ // `kTVMFFIAny` code will never appear in Any::type_index. + // However, it may appear in field annotations during reflection. + // + kTVMFFIAny = -1, kTVMFFINone = 0, kTVMFFIInt = 1, - kTVMFFIFloat = 2, - kTVMFFIOpaquePtr = 3, - kTVMFFIDataType = 4, - kTVMFFIDevice = 5, - kTVMFFIRawStr = 6, + kTVMFFIBool = 2, + kTVMFFIFloat = 3, + kTVMFFIOpaquePtr = 4, + kTVMFFIDataType = 5, + kTVMFFIDevice = 6, + kTVMFFIDLTensorPtr = 7, + kTVMFFIRawStr = 8, // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin) + // roughly ordered in terms of their potential dependencies kTVMFFIStaticObjectBegin = 64, kTVMFFIObject = 64, - kTVMFFIArray = 65, - kTVMFFIMap = 66, - kTVMFFIError = 67, - kTVMFFIFunc = 68, - kTVMFFIStr = 69, + kTVMFFIStr = 65, + kTVMFFIError = 66, + kTVMFFIFunc = 67, + kTVMFFIArray = 68, + kTVMFFIMap = 69, + kTVMFFIShapeTuple = 70, + kTVMFFINDArray = 71, + kTVMFFIRuntimeModule = 72, // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) // kTVMFFIDynObject is used to indicate that the type index // is dynamic and needs to be looked up at runtime @@ -113,7 +126,10 @@ typedef struct TVMFFIAny { * \note The type index of Object and Any are shared in FFI. */ int32_t type_index; - /*! \brief length for on-stack Any object, such as small-string */ + /*! + * \brief length for on-stack Any object, such as small-string + * \note This field is reserved for future compaction. + */ int32_t small_len; union { // 8 bytes int64_t v_int64; // integers @@ -134,6 +150,99 @@ typedef struct { const char* bytes; } TVMFFIByteArray; +/*! + * \brief Type that defines C-style safe call convention + * + * Safe call explicitly catches exceptions at the function boundary. + * + * \param self The function handle + * \param num_args Number of input arguments + * \param args The input arguments to the call. + * \param result Store output result + * + * \return The call returns 0 if call is successful. + * It returns non-zero value if there is an error.
+ * + * Possible return values of the API functions: + * * 0: success + * * -1: error happens, can be retrieved by TVMFFIGetLastError + * * -2: a frontend error occurred and recorded in the frontend. + * + * \note We decided to leverage TVMFFIGetLastError and TVMFFISetLastError + * for C function error propagation. This design choice, while + * introducing a dependency for TLS runtime, simplifies error + * propagation in chains of calls in compiler codegen. + * We do not need to propagate errors through arguments; we simply + * set them in the runtime environment. + */ +typedef int (*TVMFFISafeCallType)(void* self, int32_t num_args, const TVMFFIAny* args, + TVMFFIAny* result); + +/*! + * \brief Getter that takes the address of a field and stores the field's value in result. + * \param field The raw address of the field. + * \param result Stores the result. + */ +typedef int (*TVMFFIFieldGetter)(void* field, TVMFFIAny* result); + +/*! + * \brief Setter that takes the address of a field and assigns value to it. + * \param field The raw address of the field. + * \param value The value to set. + */ +typedef int (*TVMFFIFieldSetter)(void* field, const TVMFFIAny* value); + +/*! + * \brief Field information for optional object reflection. + */ +typedef struct { + /*! \brief The name of the field. */ + const char* name; + /*! + * \brief Records the static type kind of the field. + * + * Possible values: + * + * - TVMFFITypeIndex::kTVMFFIObject for general objects + * - The value is nullable when kTVMFFIObject is chosen + * - static object type indices such as kTVMFFIStr, kTVMFFIArray, kTVMFFIMap + * - POD type index + * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info + * about the field. + * + * \note This information is helpful when designing a serializer + * for the field, as it helps to narrow down the type of the + * object. It also helps to provide opportunities to enable + * short-cut access to the field. + */ + int32_t field_static_type_index; + /*! + * \brief Mark whether the field is readonly.
+ */ + int32_t readonly; + /*! + * \brief Byte offset of the field from the start of the object. + */ + int64_t byte_offset; + /*! \brief The getter to access the field. */ + TVMFFIFieldGetter getter; + /*! \brief The setter to access the field. */ + TVMFFIFieldSetter setter; +} TVMFFIFieldInfo; + +/*! + * \brief Method information that can appear in the reflection table. + */ +typedef struct { + /*! \brief The name of the method. */ + const char* name; + /*! + * \brief The method wrapped as Function + * \note The first argument to the method is always self. + */ + TVMFFIObjectHandle method; +} TVMFFIMethodInfo; + /*! * \brief Runtime type information for object type checking. */ @@ -155,38 +264,27 @@ typedef struct { * hieracy stays as a tree */ const int32_t* type_acenstors; + /*! \brief number of reflection accessible fields. */ + int32_t num_fields; + /*! \brief number of reflection accessible methods. */ + int32_t num_methods; + /*! \brief The reflection field information (num_fields entries). */ + TVMFFIFieldInfo* fields; + /*! \brief The reflection method information (num_methods entries). */ + TVMFFIMethodInfo* methods; } TVMFFITypeInfo; //------------------------------------------------------------ // Section: User APIs to interact with the FFI //------------------------------------------------------------ /*! - * \brief Type that defines C-style safe call convention - * - * Safe call explicitly catches exception on function boundary. - * - * \param func The function handle - * \param num_args Number if input arguments - * \param args The input arguments to the call. - * \param result Store output result - * - * \return The call return 0 if call is successful. - * It returns non-zero value if there is an error. - *
This design choice, while - * introducing a dependency for TLS runtime, simplifies error - * propgation in chains of calls in compiler codegen. - * As we do not need to propagate error through argument but simply - * set them in the runtime environment. + * \brief Free an object handle by decreasing reference + * \param obj The object handle. + * \note Internally we decrease the reference counter of the object. + * The object will be freed when every reference to the object are removed. + * \return 0 when success, nonzero when failure happens */ -typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args, - TVMFFIAny* result); +TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); /*! * \brief Create a FFIFunc by passing in callbacks from C callback. @@ -223,15 +321,6 @@ TVM_FFI_DLL int TVMFFIFuncSetGlobal(const char* name, TVMFFIObjectHandle f, int */ TVM_FFI_DLL int TVMFFIFuncGetGlobal(const char* name, TVMFFIObjectHandle* out); -/*! - * \brief Free an object handle by decreasing reference - * \param obj The object handle. - * \note Internally we decrease the reference counter of the object. - * The object will be freed when every reference to the object are removed. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj); - /*! * \brief Move the last error from the environment to result. * @@ -249,6 +338,30 @@ TVM_FFI_DLL void TVMFFIMoveFromLastError(TVMFFIAny* result); */ TVM_FFI_DLL void TVMFFISetLastError(const TVMFFIAny* error_view); +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFITypeKey2Index(const char* type_key, int32_t* out_tindex); + +/*! + * \brief Register type field information for rutnime reflection. 
+ * \param type_index The type index + * \param info The field info to be registered; the struct and its name are copied. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info); + +/*! + * \brief Register type method information for runtime reflection. + * \param type_index The type index + * \param info The method info to be registered. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIRegisterTypeMethod(int32_t type_index, const TVMFFIMethodInfo* info); + //------------------------------------------------------------ // Section: Backend noexcept functions for internal use // @@ -292,13 +405,13 @@ TVM_FFI_DLL int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t stat int32_t type_depth, int32_t num_child_slots, int32_t child_slots_can_overflow, int32_t parent_type_index); + /*! * \brief Get dynamic type info by type index. * * \param type_index The type index * \param result The output type information - * - * \return 0 when success, nonzero when failure happens + * \return Pointer to the type info registered for type_index. */ TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index); diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index 56b7ca7df3..0f6701e509 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -161,7 +161,7 @@ class MovableArgValueWithContext { template <typename Type> TVM_FFI_INLINE operator Type() { using TypeWithoutCR = std::remove_const_t<std::remove_reference_t<Type>>; - std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]); + std::optional<TypeWithoutCR> opt = TryAs<TypeWithoutCR>(args_[arg_index_]); if (opt.has_value()) { return std::move(*opt); } @@ -210,8 +210,8 @@ struct unpack_call_dispatcher<R, 0, index, F> { template <int index, typename F> struct unpack_call_dispatcher<void, 0, index, F> { template <typename...
Args> - TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, - const F& f, int32_t num_args, const AnyView* args, Any* rv, + TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature , + const F& , int32_t , const AnyView* , Any* , Args&&... unpacked_args) { f(std::forward<Args>(unpacked_args)...); } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 6a202ba1f5..70b68f6fc6 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -39,7 +39,7 @@ using TypeInfo = TVMFFITypeInfo; namespace details { // Helper to perform // unsafe operations related to object -struct ObjectUnsafe; +class ObjectUnsafe; /*! * Check if the type_index is an instance of TargetObjectType. @@ -445,32 +445,6 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) -/*! - * \brief helper macro to declare a base object type that can be inherited. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ - TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType); \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ - TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ - static inline int32_t _type_index = _GetOrAllocRuntimeTypeIndex() - -/*! - * \brief helper macro to declare type information in a final class. - * \param TypeName The name of the current type. 
- * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ - static const constexpr int _type_child_slots = 0; \ - static const constexpr bool _type_final = true; \ - TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) - /* * \brief Define object reference methods. * \param TypeName The object type name @@ -539,7 +513,8 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { * \note These functions are only supposed to be used by internal * implementations and not external users of the tvm::ffi */ -struct ObjectUnsafe { +class ObjectUnsafe { + public: // NOTE: get ffi header from an object static TVM_FFI_INLINE TVMFFIObject* GetHeader(const Object* src) { return const_cast<TVMFFIObject*>(&(src->header_)); diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h new file mode 100644 index 0000000000..51573d1bc2 --- /dev/null +++ b/ffi/include/tvm/ffi/reflection.h @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection.h + * \brief Base reflection support to access object fields. 
+ */ +#ifndef TVM_FFI_REFLECTION_H_ +#define TVM_FFI_REFLECTION_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/function.h> +#include <tvm/ffi/type_traits.h> + +namespace tvm { +namespace ffi { +namespace details { + +template <typename T, typename = void> +struct Type2FieldStaticTypeIndex { + static constexpr int32_t value = TypeIndex::kTVMFFIAny; +}; + +template <typename T> +struct Type2FieldStaticTypeIndex<T, std::enable_if_t<TypeTraits<T>::enabled>> { + static constexpr int32_t value = TypeTraits<T>::field_static_type_index; +}; + +/*! + * \brief Get the byte offset of a class member field. + * \note Applies field_ptr to a null Class pointer (formally UB in ISO C++; offsetof is the standard alternative). + * \tparam Class The original class. + * \tparam T the field type. + * + * \param field_ptr A class member pointer + * \returns The byte offset of the field within Class. + */ +template <typename Class, typename T> +inline int64_t GetFieldByteOffset(T Class::*field_ptr) { + return reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr)); +} + +class ReflectionDef { + public: + explicit ReflectionDef(int32_t type_index) : type_index_(type_index) {} + + template <typename Class, typename T> + ReflectionDef& def_readonly(const char* name, T Class::*field_ptr) { + RegisterField(name, field_ptr, true); + return *this; + } + + template <typename Class, typename T> + ReflectionDef& def_readwrite(const char* name, T Class::*field_ptr) { + RegisterField(name, field_ptr, false); + return *this; + } + + operator int32_t() const { return type_index_; } + + private: + template <typename Class, typename T> + void RegisterField(const char* name, T Class::*field_ptr, bool readonly) { + TVMFFIFieldInfo info; + info.name = name; + info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value; + // store the byte offset plus type-generic getter/setter, + // so the same getter/setter can be reused for every field of the same type + info.byte_offset = GetFieldByteOffset<Class, T>(field_ptr); + info.readonly = readonly; + info.getter = FieldGetter<T>; + info.setter = FieldSetter<T>; +
TVM_FFI_CHECK_SAFE_CALL(TVMFFIRegisterTypeField(type_index_, &info)); + } + + template <typename T> + static int FieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + Any(*reinterpret_cast<T*>(field)).MoveToTVMFFIAny(result); + TVM_FFI_SAFE_CALL_END(); + } + + template <typename T> + static int FieldSetter(void* field, const TVMFFIAny* value) { + TVM_FFI_SAFE_CALL_BEGIN(); + *reinterpret_cast<T*>(field) = AnyView::CopyFromTVMFFIAny(*value); + TVM_FFI_SAFE_CALL_END(); + } + + int32_t type_index_; +}; + +/*! + * \brief helper function to get reflection field info by type key and field name + */ +inline const TVMFFIFieldInfo* GetReflectionFieldInfo(const char* type_key, const char* field_name) { + int32_t type_index; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKey2Index(type_key, &type_index)); + const TypeInfo* info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_fields; ++i) { + if (std::strcmp(info->fields[i].name, field_name) == 0) { + return &(info->fields[i]); + } + } + TVM_FFI_THROW(RuntimeError) << "Cannot find field " << field_name << " in " << type_key; +} + +/*! + * \brief helper wrapper class to obtain a getter. + */ +class ReflectionFieldGetter { + public: + explicit ReflectionFieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) { + } + + Any operator()(const Object* obj_ptr) const { + Any result; + const void* addr = reinterpret_cast<const char*>(obj_ptr) + field_info_->byte_offset; + TVM_FFI_CHECK_SAFE_CALL(field_info_->getter(const_cast<void*>(addr), reinterpret_cast<TVMFFIAny*>(&result))); + return result; + } + + Any operator()(const ObjectPtr<Object>& obj_ptr) const { + return operator()(obj_ptr.get()); + } + + Any operator()(const ObjectRef& obj) const { + return operator()(obj.get()); + } + + private: + const TVMFFIFieldInfo* field_info_; +}; + +} // namespace details + +/*! + * \brief helper macro to declare a base object type that can be inherited. 
+ * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \ + TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType); \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static int32_t tindex = TVMFFIGetOrAllocTypeIndex( \ + TypeName::_type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } \ + static inline int32_t _type_index = \ + ::tvm::ffi::details::ReflectionDef(_GetOrAllocRuntimeTypeIndex()) + +/*! + * \brief helper macro to declare type information in a final class. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ + static const constexpr int _type_child_slots = 0; \ + static const constexpr bool _type_final = true; \ + TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_H_ diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index a46d04ad83..cf88eb7319 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -107,6 +107,8 @@ struct TypeTraitsBase { // None template <> struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; + static TVM_FFI_INLINE void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; // invariant: the pointer field also equals nullptr @@ -142,6 +144,8 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { // Integer POD values template <typename Int> struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeTraitsBase 
{ + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; + static TVM_FFI_INLINE void CopyToAnyView(const Int& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIInt; result->v_int64 = static_cast<int64_t>(src); @@ -171,6 +175,8 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT template <typename Float> struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; + static TVM_FFI_INLINE void CopyToAnyView(const Float& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIFloat; result->v_float64 = static_cast<double>(src); @@ -205,6 +211,8 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> // void* template <> struct TypeTraits<void*> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIOpaquePtr; + static TVM_FFI_INLINE void CopyToAnyView(void* src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIOpaquePtr; // maintain padding zero in 32bit platform @@ -241,6 +249,7 @@ template <typename TObjRef> struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef> && use_default_type_traits_v<TObjRef>>> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; using ContainerType = typename TObjRef::ContainerType; static TVM_FFI_INLINE void CopyToAnyView(const TObjRef& src, TVMFFIAny* result) { diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index 8941753df8..a0fe306e78 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -52,6 +52,8 @@ class TypeTable { std::string type_key_data; /*! \brief acenstor information */ std::vector<int32_t> type_acenstors_data; + /*! 
\brief type fields information */ + std::vector<TVMFFIFieldInfo> type_fields_data; // NOTE: the indices in [index, index + num_reserved_slots) are // reserved for the child-class of this type. /*! \brief Total number of slots reserved for the type and its children. */ @@ -87,6 +89,11 @@ class TypeTable { this->type_key = this->type_key_data.c_str(); this->type_key_hash = std::hash<std::string>()(this->type_key_data); this->type_acenstors = type_acenstors_data.data(); + // initialize the reflection information + this->num_fields = 0; + this->num_methods = 0; + this->fields = nullptr; + this->methods = nullptr; } }; @@ -165,13 +172,23 @@ class TypeTable { return it->second; } - const TypeInfo* GetTypeInfo(int32_t type_index) { - const TypeInfo* info = nullptr; + Entry* GetTypeEntry(int32_t type_index) { + Entry* entry = nullptr; if (type_index >= 0 && static_cast<size_t>(type_index) < type_table_.size()) { - info = type_table_[type_index].get(); + entry = type_table_[type_index].get(); } - TVM_FFI_ICHECK(info != nullptr) << "Cannot find type info for type_index=" << type_index; - return info; + TVM_FFI_ICHECK(entry != nullptr) << "Cannot find type info for type_index=" << type_index; + return entry; + } + + void RegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { + Entry* entry = GetTypeEntry(type_index); + TVMFFIFieldInfo field_data = *info; + field_data.name = this->CopyString(info->name); + entry->type_fields_data.push_back(field_data); + // refresh the pointer: push_back may reallocate, invalidating previously obtained fields pointers + entry->fields = entry->type_fields_data.data(); + entry->num_fields = static_cast<int32_t>(entry->type_fields_data.size()); } void Dump(int min_children_count) { @@ -217,9 +234,17 @@ class TypeTable { -1); } + const char* CopyString(const char* c_str) { + std::unique_ptr<std::string> val = std::make_unique<std::string>(c_str); + const char* c_val = val->c_str(); + string_pool_.emplace_back(std::move(val)); + return c_val; + } + int32_t
type_counter_{TypeIndex::kTVMFFIDynObjectBegin}; std::vector<std::unique_ptr<Entry>> type_table_; std::unordered_map<std::string, int32_t> type_key2index_; + std::vector<std::unique_ptr<std::string>> string_pool_; }; } // namespace ffi } // namespace tvm @@ -230,6 +255,18 @@ int TVMFFIObjectFree(TVMFFIObjectHandle handle) { TVM_FFI_SAFE_CALL_END(); } +int TVMFFITypeKey2Index(const char* type_key, int32_t* out_tindex) { + TVM_FFI_SAFE_CALL_BEGIN(); + out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKey2Index(type_key); + TVM_FFI_SAFE_CALL_END(); +} + +int TVMFFIRegisterTypeField(int32_t type_index, const TVMFFIFieldInfo* info) { + TVM_FFI_SAFE_CALL_BEGIN(); + tvm::ffi::TypeTable::Global()->RegisterTypeField(type_index, info); + TVM_FFI_SAFE_CALL_END(); +} + int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t static_type_index, int32_t type_depth, int32_t num_child_slots, int32_t child_slots_can_overflow, int32_t parent_type_index) { @@ -242,6 +279,6 @@ int32_t TVMFFIGetOrAllocTypeIndex(const char* type_key, int32_t static_type_inde const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index) { TVM_FFI_LOG_EXCEPTION_CALL_BEGIN(); - return tvm::ffi::TypeTable::Global()->GetTypeInfo(type_index); + return tvm::ffi::TypeTable::Global()->GetTypeEntry(type_index); TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIGetTypeInfo); } diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc new file mode 100644 index 0000000000..c7901a4ca3 --- /dev/null +++ b/ffi/tests/cpp/test_reflection.cc @@ -0,0 +1,49 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <gtest/gtest.h> +#include <tvm/ffi/object.h> +#include <tvm/ffi/reflection.h> +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +struct A { + ObjectRef obj; + int32_t x; + int32_t y; +}; + +TEST(Reflection, GetFieldByteOffset) { + EXPECT_EQ(details::GetFieldByteOffset(&A::x), 8); + EXPECT_EQ(details::GetFieldByteOffset(&A::y), 12); +} + + +TEST(Reflection, FieldGetter) { + ObjectRef a = TInt(10); + details::ReflectionFieldGetter getter( + details::GetReflectionFieldInfo("test.Int", "value") + ); + EXPECT_EQ(getter(a).operator int(), 10); +} +} // namespace diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index e660b2751e..a4d6b1353f 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -19,8 +19,10 @@ #ifndef TVM_FFI_TESTING_OBJECT_H_ #define TVM_FFI_TESTING_OBJECT_H_ + #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> +#include <tvm/ffi/reflection.h> namespace tvm { namespace ffi { @@ -46,7 +48,8 @@ class TIntObj : public TNumberObj { TIntObj(int64_t value) : value(value) {} static constexpr const char* _type_key = "test.Int"; - TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); + + TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj).def_readonly("value", &TIntObj::value); }; class TInt : public TNumber {
