This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 45f4ff950f1247734ffb60202d71310743bc752f Author: tqchen <[email protected]> AuthorDate: Sun Aug 11 10:38:15 2024 -0400 [FFI] ObjectRef based Error --- ffi/include/tvm/ffi/c_ffi_abi.h | 46 ++++---- ffi/include/tvm/ffi/c_ffi_api.h | 60 +++++++++++ ffi/include/tvm/ffi/error.h | 123 +++++++++++++++++++++ ffi/include/tvm/ffi/internal_utils.h | 28 +++++ ffi/include/tvm/ffi/memory.h | 10 +- ffi/include/tvm/ffi/object.h | 203 ++++++++++++++++++++++++++++++++--- ffi/src/ffi/registry.cc | 1 - ffi/tests/example/test_error.cc | 28 +++++ 8 files changed, 460 insertions(+), 39 deletions(-) diff --git a/ffi/include/tvm/ffi/c_ffi_abi.h b/ffi/include/tvm/ffi/c_ffi_abi.h index f1a1ae78da..34d25bf17b 100644 --- a/ffi/include/tvm/ffi/c_ffi_abi.h +++ b/ffi/include/tvm/ffi/c_ffi_abi.h @@ -19,7 +19,13 @@ /* * \file tvm/ffi/c_ffi_abi.h - * \brief This file defines + * \brief This file defines the ABI convention of the FFI convention + * + * \note This file only include data structures that can be used in + * a header only way. The APIs are defined in c_ffi_api.h + * and requires linking to tvm_ffi library. + * + * Only use the APIs when TVM_FFI_ALLOW_DYN_TYPE is set to true */ #ifndef TVM_FFI_C_FFI_ABI_H_ #define TVM_FFI_C_FFI_ABI_H_ @@ -27,23 +33,11 @@ #include <dlpack/dlpack.h> #include <stdint.h> -#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) -#include <emscripten/emscripten.h> -#define TVM_FFI_API EMSCRIPTEN_KEEPALIVE -#endif -#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) -#ifdef TVM_FFI_EXPORTS -#define TVM_FFI_DLL __declspec(dllexport) -#else -#define TVM_FFI_DLL __declspec(dllimport) -#endif -#endif -#ifndef TVM_FFI_DLL -#define TVM_FFI_DLL __attribute__((visibility("default"))) -#endif - +/*! + * \brief Macro defines whether we enable dynamic runtime features. + */ #ifndef TVM_FFI_ALLOW_DYN_TYPE -#define TVM_FFI_ALLOW_DYN_TYPE 0 +#define TVM_FFI_ALLOW_DYN_TYPE 1 #endif #ifdef __cplusplus @@ -100,7 +94,7 @@ typedef struct TVMFFIObject { /*! \brief Reference counter of the object. */ int32_t ref_counter; /*! \brief Deleter to be invoked when reference counter goes to zero. */ - void (*deleter)(struct TVMFFIObject* self); + void (*deleter)(void* self); } TVMFFIObject; /*! @@ -130,8 +124,22 @@ typedef struct TVMFFIAny { }; } TVMFFIAny; +/*! \brief Safe byte array */ +typedef struct { + int64_t num_bytes; + const char* bytes; +} TVMFFIByteArray; + +/*! \brief The error type. */ +typedef struct { + /*! \brief header */ + TVMFFIObject header_; + + +} TVMFFIError; + + #ifdef __cplusplus } // TVM_FFI_EXTERN_C #endif - #endif // TVM_FFI_C_FFI_ABI_H_ diff --git a/ffi/include/tvm/ffi/c_ffi_api.h b/ffi/include/tvm/ffi/c_ffi_api.h new file mode 100644 index 0000000000..dda54f0032 --- /dev/null +++ b/ffi/include/tvm/ffi/c_ffi_api.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/c_ffi_abi.h + * \brief This file defines the ABI convention of the FFI convention + * + * Including global calling conventions + */ +#ifndef TVM_FFI_C_FFI_ABI_H_ +#define TVM_FFI_C_FFI_ABI_H_ + +#include <tvm/ffi/c_ffi_abi.h> + +#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) +#include <emscripten/emscripten.h> +#define TVM_FFI_API EMSCRIPTEN_KEEPALIVE +#endif +#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) +#ifdef TVM_FFI_EXPORTS +#define TVM_FFI_DLL __declspec(dllexport) +#else +#define TVM_FFI_DLL __declspec(dllimport) +#endif +#endif +#ifndef TVM_FFI_DLL +#define TVM_FFI_DLL __attribute__((visibility("default"))) +#endif + +#ifdef __cplusplus +static_assert( + TVM_FFI_ALLOW_DYN_TYPE, + "Only include c_ffi_abi when TVM_FFI_ALLOW_DYN_TYPE is set to true" +); +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +} // TVM_FFI_EXTERN_C +#endif +#endif // TVM_FFI_C_FFI_ABI_H_ diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h new file mode 100644 index 0000000000..7b40f00bf9 --- /dev/null +++ b/ffi/include/tvm/ffi/error.h @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/error.h + * \brief Error handling component. + */ +#ifndef TVM_FFI_ERROR_H_ +#define TVM_FFI_ERROR_H_ + +#include <tvm/ffi/object.h> +#include <tvm/ffi/memory.h> + +#include <string> +#include <sstream> + +namespace tvm { +namespace ffi { + +/*! + * \brief Error object class. + */ +class ErrorObj: public Object { + public: + /*! \brief The error kind */ + std::string kind; + /*! \brief Message the error message. */ + std::string message; + /*! \brief Backtrace, follows python convention(most recent last). */ + std::string backtrace; + /*! \brief Full message in what_str */ + std::string what_str; + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; + static constexpr const char* _type_key = "object.Error"; + + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ErrorObj, Object); +}; + +/*! + * \brief Managed reference to ErrorObj + * \sa Error Object + */ +class Error : + public ObjectRef, + public std::exception { + public: + Error(std::string kind, std::string message, std::string backtrace) { + std::ostringstream what; + what << "Traceback (most recent call last):\n" << backtrace << kind << ": " << message << '\n'; + ObjectPtr<ErrorObj> n = make_object<ErrorObj>(); + n->kind = std::move(kind); + n->message = std::move(message); + n->backtrace = std::move(backtrace); + n->what_str = what.str(); + data_ = std::move(n); + } + + const char* what() const noexcept(true) override { + return get()->what_str.c_str(); + } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj) +}; + +namespace details { + +class ErrorBuilder { + public: + explicit ErrorBuilder(const char* kind, const char* filename, const char* func, int32_t lineno) + : kind_(kind) { + std::ostringstream backtrace; + // python style backtrace + backtrace << " " << filename << ", line " << lineno << ", in " << func << '\n'; + backtrace_ = backtrace.str(); + } + +// MSVC disable warning in error builder as it is exepected +#ifdef _MSC_VER +#pragma disagnostic push +#pragma warning(disable : 4722) +#endif + [[noreturn]] ~ErrorBuilder() noexcept(false) { + throw ::tvm::ffi::Error(std::move(kind_), message_.str(), std::move(backtrace_)); + } +#ifdef _MSC_VER +#pragma disagnostic pop +#endif + + std::ostringstream &Get() { return message_; } + +protected: + std::string kind_; + std::ostringstream message_; + std::string backtrace_; +}; +} // namespace details + +/*! + * \brief Helper macro to throw an error with backtrace and message + */ +#define TVM_FFI_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, __FILE__, TVM_FFI_FUNC_SIG, __LINE__).Get() + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ERROR_H_ diff --git a/ffi/include/tvm/ffi/internal_utils.h b/ffi/include/tvm/ffi/internal_utils.h index 53d08c364c..5eca030028 100644 --- a/ffi/include/tvm/ffi/internal_utils.h +++ b/ffi/include/tvm/ffi/internal_utils.h @@ -39,6 +39,34 @@ #define TVM_FFI_UNREACHABLE() __builtin_unreachable() #endif +/*! \brief helper macro to suppress unused warning */ +#if defined(__GNUC__) +#define TVM_FFI_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define TVM_FFI_ATTRIBUTE_UNUSED +#endif + +#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y +#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) + +#if defined(__GNUC__) || defined(__clang__) +#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ +#elif defined(_MSC_VER) +#define TVM_FFI_FUNC_SIG __FUNCSIG__ +#else +#define TVM_FFI_FUNC_SIG __func__ +#endif + +/* + * \brief Define the default copy/move constructor and assign operator + * \param TypeName The class typename. + */ +#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; + namespace tvm { namespace ffi { diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index cb7067502e..ca3021ff78 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -33,7 +33,7 @@ namespace tvm { namespace ffi { /*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(TVMFFIObject* obj); +typedef void (*FObjectDeleter)(void* obj); /*! * \brief Allocate an object using default allocator. @@ -74,11 +74,11 @@ class ObjAllocatorBase { using Handler = typename Derived::template Handler<T>; static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...); - TVMFFIObject* ffi_ptr = details::ObjectInternal::StaticCast<TVMFFIObject*>(ptr); + TVMFFIObject* ffi_ptr = details::ObjectInternal::GetHeader(ptr); // NOTE: ref_counter is initialized in object constructor ffi_ptr->type_index = T::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectInternal::ObjectPtr<T>(ptr); + return details::ObjectInternal::ObjectPtrFromUnowned<T>(ptr); } /*! @@ -132,11 +132,11 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(TVMFFIObject* objptr) { + static void Deleter_(void* objptr) { // NOTE: this is important to cast back to T* // because objptr and tptr may not be the same // depending on how sub-class allocates the space. - T* tptr = details::ObjectInternal::StaticCast<T*>(objptr); + T* tptr = static_cast<T*>(objptr); // It is important to do tptr->T::~T(), // so that we explicitly call the specific destructor // instead of tptr->~T(), which could mean the intention diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 6b973f943c..037b08bbca 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -75,11 +75,15 @@ struct ObjectInternal; * New objects can be created using make_object function. * Which will automatically populate the type_index and deleter of the object. */ -class Object : private TVMFFIObject { +class Object { + private: + /*! \brief header field that is the common prefix of all objects */ + TVMFFIObject header_; + public: Object() { - TVMFFIObject::ref_counter = 0; - TVMFFIObject::deleter = nullptr; + header_.ref_counter = 0; + header_.deleter = nullptr; } // Information about the object @@ -109,13 +113,13 @@ class Object : private TVMFFIObject { private: /*! \brief increase reference count */ - void IncRef() { details::AtomicIncrementRelaxed(&(this->ref_counter)); } + void IncRef() { details::AtomicIncrementRelaxed(&(header_.ref_counter)); } /*! \brief decrease reference count and delete the object */ void DecRef() { - if (details::AtomicDecrementRelAcq(&(this->ref_counter)) == 1) { - if (this->deleter != nullptr) { - this->deleter(this); + if (details::AtomicDecrementRelAcq(&(header_.ref_counter)) == 1) { + if (header_.deleter != nullptr) { + header_.deleter(this); } } } @@ -124,7 +128,7 @@ class Object : private TVMFFIObject { * \return The usage count of the cell. * \note We use stl style naming to be consistent with known API in shared_ptr. */ - int32_t use_count() const { return details::AtomicLoadRelaxed(&(this->ref_counter)); } + int32_t use_count() const { return details::AtomicLoadRelaxed(&(header_.ref_counter)); } // friend classes template <typename> @@ -284,6 +288,178 @@ class ObjectPtr { friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr); }; + +// Forward declaration, to prevent circular includes. +template <typename T> +class Optional; + +/*! \brief Base class of all object reference */ +class ObjectRef { + public: + /*! \brief default constructor */ + ObjectRef() = default; + /*! \brief Constructor from existing object ptr */ + explicit ObjectRef(ObjectPtr<Object> data) : data_(data) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef& other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef& other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef& other) const { return data_ != other.data_; } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } + /*! + * \return whether the object is defined(not null). + */ + bool defined() const { return data_ != nullptr; } + /*! \return the internal object pointer */ + const Object* get() const { return data_.get(); } + /*! \return the internal object pointer */ + const Object* operator->() const { return get(); } + /*! \return whether the reference is unique */ + bool unique() const { return data_.unique(); } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_.use_count(); } + + /*! + * \brief Try to downcast the internal Object to a + * raw pointer of a corresponding type. + * + * The function will return a nullptr if the cast failed. + * + * if (const AddNode *ptr = node_ref.as<AddNode>()) { + * // This is an add node + * } + * + * \tparam ObjectType the target type, must be a subtype of Object + */ + template <typename ObjectType, typename = std::enable_if_t<std::is_base_of_v<Object, ObjectType>>> + inline const ObjectType* as() const; + + /*! \brief type indicate the container type. */ + using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; + + protected: + /*! \brief Internal pointer that backs the reference. */ + ObjectPtr<Object> data_; + /*! \return return a mutable internal ptr, can be used by sub-classes. */ + Object* get_mutable() const { return data_.get(); } + // friend classes. + friend struct ObjectPtrHash; + friend class tvm::ffi::details::ObjectInternal; + template <typename SubRef, typename BaseRef> + friend SubRef Downcast(BaseRef ref); +}; + +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template <typename BaseType, typename ObjectType> +inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); + +/*! \brief ObjectRef hash functor */ +struct ObjectPtrHash { + size_t operator()(const ObjectRef& a) const { return operator()(a.data_); } + + template <typename T> + size_t operator()(const ObjectPtr<T>& a) const { + return std::hash<Object*>()(a.get()); + } +}; + +/*! \brief ObjectRef equal functor */ +struct ObjectPtrEqual { + bool operator()(const ObjectRef& a, const ObjectRef& b) const { return a.same_as(b); } + + template <typename T> + size_t operator()(const ObjectPtr<T>& a, const ObjectPtr<T>& b) const { + return a == b; + } +}; + +/*! + * \brief helper macro to declare a base object type that can be inherited. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ + static_assert(!ParentType::_type_final, "ParentObj marked as final"); \ + static int32_t RuntimeTypeIndex() { \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \ + TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + static_assert(TypeName::_type_index != TypeIndex::kTVMFFIDynObject, \ + "Static object cannot have dynamic type index."); \ + return TypeName::_type_index; \ + } + +#define TVM_FFI_OBJECT_REG_VAR_DEF static TVM_ATTRIBUTE_UNUSED uint32_t __make_Object_tid + +/*! + * \brief Helper macro to register the object type to runtime. + * Makes sure that the runtime type table is correctly populated. + * + * Use this macro in the cc file for each terminal class. + */ +#define TVM_FFI_REGISTER_OBJECT_TYPE(TypeName) \ + TVM_FFI_STR_CONCAT(TVM_FFI_OBJECT_REG_VAR_DEF, __COUNTER__) = \ + TypeName::_GetOrAllocRuntimeTypeIndex() + +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + using ContainerType = ObjectName; + +/* + * \brief Define object reference methods that is not nullable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::ffi::ObjectPtr<::tvm::ffi::Object> n) : ParentType(n) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \ + const ObjectName* get() const { return operator->(); } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + namespace details { /*! * \brief Namespace to internally manipulate object class. @@ -291,15 +467,14 @@ namespace details { * implementations and not external users of the tvm::ffi */ struct ObjectInternal { - // NOTE: these helper to perform static cast - // that also allows conversion from/to FFI values - template <typename T, typename U> - static TVM_FFI_INLINE T StaticCast(U src) { - return static_cast<T>(src); + // NOTE: get ffi header from an object + static TVM_FFI_INLINE TVMFFIObject* GetHeader(Object* src) { + return &(src->header_); } + // create ObjectPtr from unknowned ptr template <typename T> - static TVM_FFI_INLINE ObjectPtr<T> ObjectPtr(Object* raw_ptr) { + static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) { return tvm::ffi::ObjectPtr<T>(raw_ptr); } }; diff --git a/ffi/src/ffi/registry.cc b/ffi/src/ffi/registry.cc index 698f81c2b6..e69de29bb2 100644 --- a/ffi/src/ffi/registry.cc +++ b/ffi/src/ffi/registry.cc @@ -1 +0,0 @@ -namespace tvm {} \ No newline at end of file diff --git a/ffi/tests/example/test_error.cc b/ffi/tests/example/test_error.cc new file mode 100644 index 0000000000..a2c073aa3f --- /dev/null +++ b/ffi/tests/example/test_error.cc @@ -0,0 +1,28 @@ +#include <gtest/gtest.h> +#include <tvm/ffi/error.h> + +namespace { + +using namespace tvm::ffi; + +void ThrowRuntimeError() { + TVM_FFI_THROW(RuntimeError) + << "test0"; +} + +TEST(Error, Traceback) { + EXPECT_THROW({ + try { + ThrowRuntimeError(); + } catch (const Error& error) { + EXPECT_EQ(error->message, "test0"); + EXPECT_EQ(error->kind, "RuntimeError"); + std::string what = error.what(); + EXPECT_NE(what.find("line"), std::string::npos); + EXPECT_NE(what.find("ThrowRuntimeError()"), std::string::npos); + EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); + throw; + } + }, ::tvm::ffi::Error); +} +} // namespace
