This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 98e38b33687c5612427daa7fd26fe36d1e7efb52 Author: tqchen <[email protected]> AuthorDate: Sun Sep 8 14:29:23 2024 -0400 [FFI][REFACTOR] Cleanup naming convention and lint --- ffi/CMakeLists.txt | 2 +- ffi/include/tvm/ffi/any.h | 29 +++--- .../tvm/ffi/{internal_utils.h => base_details.h} | 23 ++--- ffi/include/tvm/ffi/c_api.h | 3 +- ffi/include/tvm/ffi/container/array.h | 18 ++-- .../ffi/container/{base.h => container_details.h} | 17 ++-- ffi/include/tvm/ffi/container/optional.h | 0 ffi/include/tvm/ffi/error.h | 13 +-- ffi/include/tvm/ffi/function.h | 81 ++++++++------- ffi/include/tvm/ffi/function_details.h | 104 +++++++++---------- ffi/include/tvm/ffi/object.h | 12 +-- ffi/include/tvm/ffi/type_traits.h | 69 ++++++------- ffi/scripts/run_tests.sh | 6 +- ffi/src/ffi/traceback.h | 2 +- ffi/tests/example/test_error.cc | 33 +++---- ffi/tests/example/test_function.cc | 110 ++++++++++----------- ffi/tests/example/testing_object.h | 8 +- tests/lint/cpplint.sh | 1 + 18 files changed, 249 insertions(+), 282 deletions(-) diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 7c830f8a0f..89429d253b 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -39,7 +39,7 @@ target_link_libraries(tvm_ffi INTERFACE dlpack_header) target_compile_features(tvm_ffi INTERFACE cxx_std_17) target_include_directories(tvm_ffi INTERFACE "${CMAKE_CURRENT_SOURCE_DIR}/include") -if (TVM_FFI_USE_LIBBRACKTRACE) +if (TVM_FFI_ALLOW_DYN_TYPE) message(STATUS "Setting C++ macro TVM_FFI_ALLOW_DYN_TYPE - 1") target_compile_definitions(tvm_ffi INTERFACE TVM_FFI_ALLOW_DYN_TYPE=1) else() diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 9912c45f12..3fad36eacc 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -26,6 +26,9 @@ #include <tvm/ffi/c_api.h> #include <tvm/ffi/type_traits.h> +#include <string> +#include <utility> + namespace tvm { namespace ffi { @@ -62,9 +65,7 @@ class AnyView { std::swap(data_, other.data_); } /*! \return the internal type index */ - int32_t type_index() const { - return data_.type_index; - } + int32_t type_index() const { return data_.type_index; } // default constructors AnyView() { data_.type_index = TypeIndex::kTVMFFINone; } ~AnyView() = default; @@ -80,7 +81,7 @@ class AnyView { // constructor from general types template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> AnyView(const T& other) { // NOLINT(*) - TypeTraits<T>::ConvertToAnyView(other, &data_); + TypeTraits<T>::CopyToAnyView(other, &data_); } template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> AnyView& operator=(const T& other) { // NOLINT(*) @@ -91,12 +92,12 @@ class AnyView { template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> std::optional<T> TryAs() const { - return TypeTraits<T>::TryConvertFromAnyView(&data_); + return TypeTraits<T>::TryCopyFromAnyView(&data_); } template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> operator T() const { - std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); + std::optional<T> opt = TypeTraits<T>::TryCopyFromAnyView(&data_); if (opt.has_value()) { return std::move(*opt); } @@ -132,7 +133,7 @@ namespace details { */ TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny* data, [[maybe_unused]] size_t extra_any_bytes = 0) { - // TODO: string conversion. + // TODO(tqchen): string conversion. if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { details::ObjectUnsafe::IncRefObjectInAny(data); } @@ -165,9 +166,7 @@ class Any { std::swap(data_, other.data_); } /*! \return the internal type index */ - int32_t type_index() const { - return data_.type_index; - } + int32_t type_index() const { return data_.type_index; } // default constructors Any() { data_.type_index = TypeIndex::kTVMFFINone; } ~Any() { this->reset(); } @@ -189,7 +188,9 @@ class Any { return *this; } // convert from/to AnyView - Any(const AnyView& other) : data_(other.data_) { details::InplaceConvertAnyViewToAny(&data_); } + explicit Any(const AnyView& other) : data_(other.data_) { + details::InplaceConvertAnyViewToAny(&data_); + } Any& operator=(const AnyView& other) { // copy-and-swap idiom Any(other).swap(*this); // NOLINT(*) @@ -210,12 +211,12 @@ class Any { } template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> std::optional<T> TryAs() const { - return TypeTraits<T>::TryConvertFromAnyView(&data_); + return TypeTraits<T>::TryCopyFromAnyView(&data_); } template <typename T, typename = std::enable_if_t<TypeTraits<T>::enabled>> operator T() const { - std::optional<T> opt = TypeTraits<T>::TryConvertFromAnyView(&data_); + std::optional<T> opt = TypeTraits<T>::TryCopyFromAnyView(&data_); if (opt.has_value()) { return std::move(*opt); } @@ -254,7 +255,7 @@ struct AnyUnsafe : public ObjectUnsafe { template <typename T> static TVM_FFI_INLINE T ConvertAfterCheck(const Any& ref) { if constexpr (!std::is_same_v<T, Any>) { - return TypeTraits<T>::ConvertFromAnyViewAfterCheck(&(ref.data_)); + return TypeTraits<T>::CopyFromAnyViewAfterCheck(&(ref.data_)); } else { return ref; } diff --git a/ffi/include/tvm/ffi/internal_utils.h b/ffi/include/tvm/ffi/base_details.h similarity index 86% rename from ffi/include/tvm/ffi/internal_utils.h rename to ffi/include/tvm/ffi/base_details.h index ffe1dd9e1a..c44ad0608a 100644 --- a/ffi/include/tvm/ffi/internal_utils.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -18,10 +18,12 @@ */ /*! * \file tvm/ffi/base_details.h - * \brief Internal use utilities + * \brief Internal detail utils that can be used by files in tvm/ffi. + * \note details header are for internal use only + * and not to be directly used by user. */ -#ifndef TVM_FFI_INTERNAL_UTILS_H_ -#define TVM_FFI_INTERNAL_UTILS_H_ +#ifndef TVM_FFI_BASE_DETAILS_H_ +#define TVM_FFI_BASE_DETAILS_H_ #include <tvm/ffi/c_api.h> @@ -74,21 +76,20 @@ * \param TypeName The class typename. */ #define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName& other) = default; \ - TypeName(TypeName&& other) = default; \ - TypeName& operator=(const TypeName& other) = default; \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ TypeName& operator=(TypeName&& other) = default; namespace tvm { namespace ffi { - namespace details { /********** Atomic Operations *********/ TVM_FFI_INLINE int32_t AtomicIncrementRelaxed(int32_t* ptr) { #ifdef _MSC_VER - return _InterlockedIncrement(reinterpret_cast<volatile long*>(ptr)) - 1; + return _InterlockedIncrement(reinterpret_cast<volatile long*>(ptr)) - 1; // NOLINT(*) #else return __atomic_fetch_add(ptr, 1, __ATOMIC_RELAXED); #endif @@ -96,7 +97,7 @@ TVM_FFI_INLINE int32_t AtomicIncrementRelaxed(int32_t* ptr) { TVM_FFI_INLINE int32_t AtomicDecrementRelAcq(int32_t* ptr) { #ifdef _MSC_VER - return _InterlockedDecrement(reinterpret_cast<volatile long*>(ptr)) + 1; + return _InterlockedDecrement(reinterpret_cast<volatile long*>(ptr)) + 1; // NOLINT(*) #else return __atomic_fetch_sub(ptr, 1, __ATOMIC_ACQ_REL); #endif @@ -106,7 +107,7 @@ TVM_FFI_INLINE int32_t AtomicLoadRelaxed(const int32_t* ptr) { int32_t* raw_ptr = const_cast<int32_t*>(ptr); #ifdef _MSC_VER // simply load the variable ptr out - return (reinterpret_cast<const volatile long*>(raw_ptr))[0]; + return (reinterpret_cast<const volatile long*>(raw_ptr))[0]; // NOLINT(*) #else return __atomic_load_n(raw_ptr, __ATOMIC_RELAXED); #endif @@ -135,4 +136,4 @@ void for_each(const F& f, Args&&... args) { // NOLINT(*) } // namespace details } // namespace ffi } // namespace tvm -#endif // TVM_FFI_INTERNAL_UTILS_H_ +#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 659ee6cd83..5c40fedda8 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -178,7 +178,8 @@ typedef struct { * It returns non-zero value if there is an error. * When error happens, the exception object will be stored in result. */ -typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* result); +typedef int (*TVMFFISafeCallType)(void* func, int32_t num_args, const TVMFFIAny* args, + TVMFFIAny* result); #ifdef __cplusplus } // TVM_FFI_EXTERN_C diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h index 2ce946a120..d4e3a88641 100644 --- a/ffi/include/tvm/ffi/container/array.h +++ b/ffi/include/tvm/ffi/container/array.h @@ -27,12 +27,12 @@ #define TVM_FFI_CONTAINER_ARRAY_H_ #include <tvm/ffi/any.h> -#include <tvm/ffi/container/base.h> +#include <tvm/ffi/container/container_details.h> #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> #include <algorithm> -#include <memory> +#include <string> #include <type_traits> #include <utility> #include <vector> @@ -41,7 +41,7 @@ namespace tvm { namespace ffi { /*! \brief array node content in array */ -class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, Any> { +class ArrayNode : public Object, public details::InplaceArrayBase<ArrayNode, Any> { public: /*! \return The size of the array */ size_t size() const { return this->size_; } @@ -391,8 +391,8 @@ class Array : public ObjectRef { static T convert(const Any& n) { return details::AnyUnsafe::ConvertAfterCheck<T>(n); } }; - using iterator = IterAdapter<ValueConverter, const Any*>; - using reverse_iterator = ReverseIterAdapter<ValueConverter, const Any*>; + using iterator = details::IterAdapter<ValueConverter, const Any*>; + using reverse_iterator = details::ReverseIterAdapter<ValueConverter, const Any*>; /*! \return begin iterator */ iterator begin() const { return iterator(GetArrayNode()->begin()); } @@ -946,7 +946,7 @@ inline constexpr bool use_default_type_traits_v<Array<T>> = false; template <typename T> struct TypeTraits<Array<T>> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const Array<T>& src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const Array<T>& src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; @@ -992,13 +992,13 @@ struct TypeTraits<Array<T>> : public TypeTraitsBase { } } - static TVM_FFI_INLINE Array<T> ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Array<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) return Array<T>(nullptr); return Array<T>(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); } - static TVM_FFI_INLINE std::optional<Array<T>> TryConvertFromAnyView(const TVMFFIAny* src) { - if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); + static TVM_FFI_INLINE std::optional<Array<T>> TryCopyFromAnyView(const TVMFFIAny* src) { + if (CheckAnyView(src)) return CopyFromAnyViewAfterCheck(src); return std::nullopt; } diff --git a/ffi/include/tvm/ffi/container/base.h b/ffi/include/tvm/ffi/container/container_details.h similarity index 95% rename from ffi/include/tvm/ffi/container/base.h rename to ffi/include/tvm/ffi/container/container_details.h index bfeeacd59b..9df7758524 100644 --- a/ffi/include/tvm/ffi/container/base.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -18,21 +18,21 @@ */ /*! - * \file tvm/ffi/container/base.h - * \brief Base utilities for common POD(plain old data) container types. + * \file tvm/ffi/container/container_details.h + * \brief Common utilities for container types. */ -#ifndef TVM_FFI_CONTAINER_BASE_H_ -#define TVM_FFI_CONTAINER_BASE_H_ +#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ +#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> -#include <algorithm> -#include <initializer_list> +#include <type_traits> #include <utility> namespace tvm { namespace ffi { +namespace details { /*! * \brief Base template for classes with array like memory layout. * @@ -46,7 +46,7 @@ namespace ffi { * * \code * // Example usage of the template to define a simple array wrapper - * class ArrayNode : public InplaceArrayBase<ArrayNode, Elem> { + * class ArrayNode : public tvm::ffi::details::InplaceArrayBase<ArrayNode, Elem> { * public: * // Wrap EmplaceInit to initialize the elements * template <typename Iterator> @@ -263,6 +263,7 @@ class ReverseIterAdapter { private: TIter iter_; }; +} // namespace details } // namespace ffi } // namespace tvm -#endif // TVM_FFI_CONTAINER_BASE_H_ +#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/container/optional.h b/ffi/include/tvm/ffi/container/optional.h deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index f894166ef3..1f703f1fac 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -24,8 +24,8 @@ #ifndef TVM_FFI_ERROR_H_ #define TVM_FFI_ERROR_H_ +#include <tvm/ffi/base_details.h> #include <tvm/ffi/c_api.h> -#include <tvm/ffi/internal_utils.h> #include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> @@ -33,6 +33,7 @@ #include <memory> #include <sstream> #include <string> +#include <utility> /*! * \brief Macro defines whether we enable libbacktrace @@ -55,7 +56,7 @@ namespace ffi { /*! * \brief Error object class. */ -class ErrorObj: public Object { +class ErrorObj : public Object { public: /*! \brief The error kind */ std::string kind; @@ -76,9 +77,7 @@ class ErrorObj: public Object { * \brief Managed reference to ErrorObj * \sa Error Object */ -class Error : - public ObjectRef, - public std::exception { +class Error : public ObjectRef, public std::exception { public: Error(std::string kind, std::string message, std::string backtrace) { std::ostringstream what; @@ -91,9 +90,7 @@ class Error : data_ = std::move(n); } - const char* what() const noexcept(true) override { - return get()->what_str.c_str(); - } + const char* what() const noexcept(true) override { return get()->what_str.c_str(); } TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Error, ObjectRef, ErrorObj); }; diff --git a/ffi/include/tvm/ffi/function.h b/ffi/include/tvm/ffi/function.h index 81b0d8fab2..e2adee03c0 100644 --- a/ffi/include/tvm/ffi/function.h +++ b/ffi/include/tvm/ffi/function.h @@ -23,12 +23,15 @@ #ifndef TVM_FFI_FUNCTION_H_ #define TVM_FFI_FUNCTION_H_ -#include <tvm/ffi/c_api.h> -#include <tvm/ffi/internal_utils.h> #include <tvm/ffi/any.h> +#include <tvm/ffi/base_details.h> +#include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> #include <tvm/ffi/function_details.h> +#include <string> +#include <utility> + namespace tvm { namespace ffi { @@ -38,7 +41,7 @@ namespace ffi { */ class FunctionObj : public Object { public: - typedef void (*FCall)(const FunctionObj*, int32_t, const AnyView* , Any*); + typedef void (*FCall)(const FunctionObj*, int32_t, const AnyView*, Any*); /*! \brief A C++ style call implementation */ FCall call; /*! \brief A C API compatible call with exception catching. */ @@ -61,15 +64,14 @@ class FunctionObj : public Object { static int32_t SafeCall(void* func, int32_t num_args, const TVMFFIAny* args, TVMFFIAny* result) { FunctionObj* self = static_cast<FunctionObj*>(func); try { - self->call(self, num_args, reinterpret_cast<const AnyView*>(args), reinterpret_cast<Any*>(result)); + self->call(self, num_args, reinterpret_cast<const AnyView*>(args), + reinterpret_cast<Any*>(result)); return 0; } catch (const tvm::ffi::Error& err) { Any(std::move(err)).MoveToTVMFFIAny(result); return 1; } catch (const std::runtime_error& err) { - Any( - tvm::ffi::Error("RuntimeError", err.what(), "") - ).MoveToTVMFFIAny(result); + Any(tvm::ffi::Error("RuntimeError", err.what(), "")).MoveToTVMFFIAny(result); return 1; } TVM_FFI_UNREACHABLE(); @@ -94,8 +96,7 @@ class FunctionObjImpl : public FunctionObj { * \brief Derived object class for constructing PackedFuncObj. * \param callable The type-erased callable object. */ - explicit FunctionObjImpl(TCallable callable) - : callable_(callable) { + explicit FunctionObjImpl(TCallable callable) : callable_(callable) { this->call = Call; this->safe_call = SafeCall; } @@ -116,11 +117,10 @@ class FunctionObjImpl : public FunctionObj { */ template <typename Derived> struct RedirectCallToSafeCall { - static void Call(const FunctionObj* func, int32_t num_args, const AnyView* args , Any* rv) { + static void Call(const FunctionObj* func, int32_t num_args, const AnyView* args, Any* rv) { Derived* self = static_cast<Derived*>(const_cast<FunctionObj*>(func)); - int ret_code = self->RedirectSafeCall( - num_args, reinterpret_cast<const TVMFFIAny*>(args), - reinterpret_cast<TVMFFIAny*>(rv)); + int ret_code = self->RedirectSafeCall(num_args, reinterpret_cast<const TVMFFIAny*>(args), + reinterpret_cast<TVMFFIAny*>(rv)); if (ret_code != 0) { if (std::optional<tvm::ffi::Error> err = rv->TryAs<tvm::ffi::Error>()) { throw std::move(*err); @@ -139,23 +139,21 @@ struct RedirectCallToSafeCall { /*! * \brief FunctionObj specialization that leverages C-style callback definitions. */ -class ExternCFunctionObjImpl : - public FunctionObj, - public RedirectCallToSafeCall<ExternCFunctionObjImpl> { +class ExternCFunctionObjImpl : public FunctionObj, + public RedirectCallToSafeCall<ExternCFunctionObjImpl> { public: using RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall; ExternCFunctionObjImpl(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) - : self_(self), safe_call_(safe_call), deleter_(deleter) { + : self_(self), safe_call_(safe_call), deleter_(deleter) { this->call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::Call; this->safe_call = RedirectCallToSafeCall<ExternCFunctionObjImpl>::SafeCall; } - ~ExternCFunctionObjImpl() { - deleter_(self_); - } + ~ExternCFunctionObjImpl() { deleter_(self_); } - TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, TVMFFIAny* rv) const { + TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, + TVMFFIAny* rv) const { return safe_call_(self_, num_args, args, rv); } @@ -168,19 +166,18 @@ class ExternCFunctionObjImpl : /*! * \brief FunctionObj specialization that wraps an external function. */ -class ImportedFunctionObjImpl : - public FunctionObj, - public RedirectCallToSafeCall<ImportedFunctionObjImpl> { +class ImportedFunctionObjImpl : public FunctionObj, + public RedirectCallToSafeCall<ImportedFunctionObjImpl> { public: using RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall; - explicit ImportedFunctionObjImpl(ObjectPtr<Object> data) - : data_(data) { + explicit ImportedFunctionObjImpl(ObjectPtr<Object> data) : data_(data) { this->call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::Call; this->safe_call = RedirectCallToSafeCall<ImportedFunctionObjImpl>::SafeCall; } - TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, TVMFFIAny* rv) const { + TVM_FFI_INLINE int32_t RedirectSafeCall(int32_t num_args, const TVMFFIAny* args, + TVMFFIAny* rv) const { FunctionObj* func = const_cast<FunctionObj*>(static_cast<const FunctionObj*>(data_.get())); return func->safe_call(func, num_args, args, rv); } @@ -192,12 +189,12 @@ class ImportedFunctionObjImpl : // Helper class to set packed arguments class PackedArgsSetter { public: - PackedArgsSetter(AnyView* args) : args_(args) {} + explicit PackedArgsSetter(AnyView* args) : args_(args) {} // NOTE: setter needs to be very carefully designed // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) // that is why we need T&& and std::forward here - template<typename T> + template <typename T> TVM_FFI_INLINE void operator()(size_t i, T&& value) const { args_[i].operator=(std::forward<T>(value)); } @@ -215,17 +212,17 @@ class Function : public ObjectRef { public: /*! \brief Constructor from null */ Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `PackedFunc` - * \param packed_call The packed function signature - */ + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `PackedFunc` + * \param packed_call The packed function signature + */ template <typename TCallable> static Function FromPacked(TCallable packed_call) { static_assert( - std::is_convertible_v<TCallable, std::function<void(int32_t, const AnyView*, Any*)>>, - "tvm::ffi::Function::FromPacked requires input function signature to match packed func format" - ); + std::is_convertible_v<TCallable, std::function<void(int32_t, const AnyView*, Any*)>>, + "tvm::ffi::Function::FromPacked requires input function signature to match packed func " + "format"); using ObjType = typename details::FunctionObjImpl<TCallable>; Function func; func.data_ = make_object<ObjType>(std::forward<TCallable>(packed_call)); @@ -261,7 +258,8 @@ class Function : public ObjectRef { * \param deleter The deleter to release the resource of self. * \return The created function. */ - static Function FromExternC(void* self, TVMFFISafeCallType safe_call, void (*deleter)(void* self)) { + static Function FromExternC(void* self, TVMFFISafeCallType safe_call, + void (*deleter)(void* self)) { // the other function coems from a different library Function func; func.data_ = make_object<details::ExternCFunctionObjImpl>(self, safe_call, deleter); @@ -276,7 +274,8 @@ class Function : public ObjectRef { static Function FromUnpacked(TCallable callable) { using FuncInfo = details::FunctionInfo<TCallable>; auto call_packed = [callable](int32_t num_args, const AnyView* args, Any* rv) -> void { - details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(nullptr, callable, num_args, args, rv); + details::unpack_call<typename FuncInfo::RetType, FuncInfo::num_args>(nullptr, callable, + num_args, args, rv); }; return FromPacked(call_packed); } @@ -326,7 +325,7 @@ class Function : public ObjectRef { * \param rv The return value. */ TVM_FFI_INLINE void CallPacked(int32_t num_args, const AnyView* args, Any* result) const { - static_cast<FunctionObj*>(data_.get())->CallPacked(num_args, args, result); + static_cast<FunctionObj*>(data_.get())->CallPacked(num_args, args, result); } /*! \return Whether the packed function is nullptr */ bool operator==(std::nullptr_t) const { return data_ == nullptr; } @@ -338,4 +337,4 @@ class Function : public ObjectRef { } // namespace ffi } // namespace tvm -#endif // TVM_FFI_OBJECT_H_ +#endif // TVM_FFI_FUNCTION_H_ diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index f22258029d..56b7ca7df3 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -23,40 +23,36 @@ #ifndef TVM_FFI_FUNCTION_DETAILS_H_ #define TVM_FFI_FUNCTION_DETAILS_H_ -#include <tvm/ffi/c_api.h> -#include <tvm/ffi/internal_utils.h> #include <tvm/ffi/any.h> +#include <tvm/ffi/base_details.h> +#include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> +#include <string> +#include <tuple> +#include <utility> + namespace tvm { namespace ffi { namespace details { template <typename Type> struct Type2Str { - static std::string v() { - return TypeTraitsNoCR<Type>::TypeStr(); - } + static std::string v() { return TypeTraitsNoCR<Type>::TypeStr(); } }; template <> struct Type2Str<Any> { - static const char* v() { - return "Any"; - } + static const char* v() { return "Any"; } }; template <> struct Type2Str<AnyView> { - static const char* v() { - return "AnyView"; - } + static const char* v() { return "AnyView"; } }; template <> struct Type2Str<void> { - static const char* v() { - return "void"; - } + static const char* v() { return "void"; } }; template <typename ArgType> @@ -70,24 +66,22 @@ struct Arg2Str { os << i << ": " << Type2Str<Arg>::v(); } template <size_t... I> - static TVM_FFI_INLINE void Run(std::ostream &os, std::index_sequence<I...>) { + static TVM_FFI_INLINE void Run(std::ostream& os, std::index_sequence<I...>) { using TExpander = int[]; (void)TExpander{0, (Apply<I>(os), 0)...}; } }; template <typename T> -static constexpr bool ArgSupported = ( - std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, Any> || - std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> || - TypeTraitsNoCR<T>::enabled -); +static constexpr bool ArgSupported = + (std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, Any> || + std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> || + TypeTraitsNoCR<T>::enabled); // NOTE: return type can only support non-reference managed returns template <typename T> -static constexpr bool RetSupported = ( - std::is_same_v<T, Any> || std::is_void_v<T> || TypeTraits<T>::enabled -); +static constexpr bool RetSupported = + (std::is_same_v<T, Any> || std::is_void_v<T> || TypeTraits<T>::enabled); template <typename R, typename... Args> struct FuncFunctorImpl { @@ -113,9 +107,9 @@ template <typename T> struct FunctionInfoHelper; template <typename T, typename R, typename... Args> -struct FunctionInfoHelper<R (T::*)(Args...)>: FuncFunctorImpl<R, Args...> {}; +struct FunctionInfoHelper<R (T::*)(Args...)> : FuncFunctorImpl<R, Args...> {}; template <typename T, typename R, typename... Args> -struct FunctionInfoHelper<R (T::*)(Args...) const>: FuncFunctorImpl<R, Args...> {}; +struct FunctionInfoHelper<R (T::*)(Args...) const> : FuncFunctorImpl<R, Args...> {}; /*! * \brief Template class to get function signature of a function or functor. @@ -133,15 +127,15 @@ struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {}; /*! \brief Using static function to output TypedPackedFunc signature */ typedef std::string (*FGetFuncSignature)(); -template<typename T> +template <typename T> TVM_FFI_INLINE std::optional<T> TryAs(AnyView arg) { return arg.TryAs<T>(); } -template<> +template <> TVM_FFI_INLINE std::optional<Any> TryAs<Any>(AnyView arg) { return Any(arg); } -template<> +template <> TVM_FFI_INLINE std::optional<AnyView> TryAs<AnyView>(AnyView arg) { return arg; } @@ -159,14 +153,10 @@ class MovableArgValueWithContext { * \param f_sig Pointer to static function outputting signature of the function being called. * named. */ - TVM_FFI_INLINE MovableArgValueWithContext( - const AnyView* args, int32_t arg_index, - const std::string* optional_name, - FGetFuncSignature f_sig) - : args_(args), - arg_index_(arg_index), - optional_name_(optional_name), - f_sig_(f_sig) {} + TVM_FFI_INLINE MovableArgValueWithContext(const AnyView* args, int32_t arg_index, + const std::string* optional_name, + FGetFuncSignature f_sig) + : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} template <typename Type> TVM_FFI_INLINE operator Type() { @@ -175,12 +165,11 @@ class MovableArgValueWithContext { if (opt.has_value()) { return std::move(*opt); } - TVM_FFI_THROW(TypeError) - << "Mismatched type on argument #" << arg_index_ << " when calling: `" - << (optional_name_ == nullptr ? "" : *optional_name_) - << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" - << Type2Str<Type>::v() << "` but got `" - << TypeIndex2TypeKey(args_[arg_index_].type_index()) << "`"; + TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ << " when calling: `" + << (optional_name_ == nullptr ? "" : *optional_name_) + << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. Expected `" + << Type2Str<Type>::v() << "` but got `" + << TypeIndex2TypeKey(args_[arg_index_].type_index()) << "`"; } private: @@ -193,8 +182,8 @@ class MovableArgValueWithContext { template <typename R, int nleft, int index, typename F> struct unpack_call_dispatcher { template <typename... Args> - TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, const F& f, - int32_t num_args, const AnyView* args, Any* rv, + TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, + const F& f, int32_t num_args, const AnyView* args, Any* rv, Args&&... unpacked_args) { // construct a movable argument value // which allows potential move of argument to the input of F. @@ -207,9 +196,8 @@ struct unpack_call_dispatcher { template <typename R, int index, typename F> struct unpack_call_dispatcher<R, 0, index, F> { template <typename... Args> - TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, - int32_t, const AnyView*, Any* rv, - Args&&... unpacked_args) { + TVM_FFI_INLINE static void run(const std::string*, FGetFuncSignature, const F& f, int32_t, + const AnyView*, Any* rv, Args&&... unpacked_args) { using RetType = decltype(f(std::forward<Args>(unpacked_args)...)); if constexpr (std::is_same_v<RetType, R>) { *rv = f(std::forward<Args>(unpacked_args)...); @@ -222,25 +210,25 @@ struct unpack_call_dispatcher<R, 0, index, F> { template <int index, typename F> struct unpack_call_dispatcher<void, 0, index, F> { template <typename... Args> - TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, const F& f, - int32_t num_args, const AnyView* args, Any* rv, - Args&&... unpacked_args) { + TVM_FFI_INLINE static void run(const std::string* optional_name, FGetFuncSignature f_sig, + const F& f, int32_t num_args, const AnyView* args, Any* rv, + Args&&... unpacked_args) { f(std::forward<Args>(unpacked_args)...); } }; template <typename R, int nargs, typename F> -TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, - int32_t num_args, const AnyView* args, Any* rv) { +TVM_FFI_INLINE void unpack_call(const std::string* optional_name, const F& f, int32_t num_args, + const AnyView* args, Any* rv) { using FuncInfo = FunctionInfo<F>; FGetFuncSignature f_sig = FuncInfo::Sig; - static_assert(FuncInfo::unpacked_supported, "The function signature cannot support unpacked call"); + static_assert(FuncInfo::unpacked_supported, + "The function signature cannot support unpacked call"); if (nargs != num_args) { - TVM_FFI_THROW(TypeError) - << "Mismatched number of arguments when calling: `" - << (optional_name == nullptr ? "" : *optional_name) - << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " - << nargs << " but got " << num_args << " arguments"; + TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" + << (optional_name == nullptr ? "" : *optional_name) + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs + << " but got " << num_args << " arguments"; } unpack_call_dispatcher<R, nargs, 0, F>::run(optional_name, f_sig, f, num_args, args, rv); } diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 13de74fa07..9e218899d7 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -23,8 +23,8 @@ #ifndef TVM_FFI_OBJECT_H_ #define TVM_FFI_OBJECT_H_ +#include <tvm/ffi/base_details.h> #include <tvm/ffi/c_api.h> -#include <tvm/ffi/internal_utils.h> #include <type_traits> #include <utility> @@ -336,7 +336,6 @@ class ObjectPtr { friend ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr); }; - // Forward declaration, to prevent circular includes. template <typename T> class Optional; @@ -469,9 +468,9 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); * \param TypeName The name of the current type. * \param ParentType The name of the ParentType */ -#define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ - TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType); \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; }\ +#define TVM_FFI_DECLARE_STATIC_OBJECT_INFO(TypeName, ParentType) \ + TVM_FFI_REGISTER_STATIC_TYPE_INFO(TypeName, ParentType); \ + static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ TVM_FFI_OBJECT_STATIC_DEFS(TypeName, ParentType) /*! @@ -499,10 +498,9 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjectType* ptr); */ #define TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \ static const constexpr int _type_child_slots = 0; \ - static const constexpr bool _type_final = true; \ + static const constexpr bool _type_final = true; \ TVM_FFI_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) - /* * \brief Define object reference methods. * \param TypeName The object type name diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index be46639642..9fc38f7354 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -23,12 +23,13 @@ #ifndef TVM_FFI_TYPE_TRAITS_H_ #define TVM_FFI_TYPE_TRAITS_H_ +#include <tvm/ffi/base_details.h> #include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> -#include <tvm/ffi/internal_utils.h> #include <tvm/ffi/object.h> #include <optional> +#include <string> #include <type_traits> namespace tvm { @@ -72,11 +73,11 @@ inline std::string TypeIndex2TypeKey(int32_t type_index) { * * We need to implement the following conversion functions * - * - void ConvertToAnyView(const T& src, TVMFFIAny* result); + * - void CopyToAnyView(const T& src, TVMFFIAny* result); * * Convert a value to AnyView * - * - std::optional<T> TryConvertFromAnyView(const TVMFFIAny* src); + * - std::optional<T> TryCopyFromAnyView(const TVMFFIAny* src); * * Try convert AnyView to a value type. */ @@ -99,7 +100,7 @@ struct TypeTraitsBase { static constexpr bool enabled = true; // get mismatched type when result mismatches the trait. - // this function is called after TryConvertFromAnyView fails + // this function is called after TryCopyFromAnyView fails // to get more detailed type information in runtime // especially when the error involves nested container type static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* source) { @@ -110,7 +111,7 @@ struct TypeTraitsBase { // None template <> struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const std::nullptr_t&, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const std::nullptr_t&, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; // invariant: the pointer field also equals nullptr // this will simplify the recovery of nullable object from the any @@ -124,7 +125,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { result->v_int64 = 0; } - static TVM_FFI_INLINE std::optional<std::nullptr_t> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional<std::nullptr_t> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFINone) { return nullptr; } @@ -135,7 +136,7 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { return src->type_index == TypeIndex::kTVMFFINone; } - static TVM_FFI_INLINE std::nullptr_t ConvertFromAnyViewAfterCheck(const TVMFFIAny*) { + static TVM_FFI_INLINE std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny*) { return nullptr; } @@ -145,16 +146,14 @@ struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { // Integer POD values template <typename Int> struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const Int& src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const Int& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIInt; result->v_int64 = static_cast<int64_t>(src); } - static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { - ConvertToAnyView(src, result); - } + static TVM_FFI_INLINE void MoveToAny(Int src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE std::optional<Int> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional<Int> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIInt) { return std::make_optional<Int>(src->v_int64); } @@ -165,7 +164,7 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT return src->type_index == TypeIndex::kTVMFFIInt; } - static TVM_FFI_INLINE int ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE int CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return static_cast<Int>(src->v_int64); } @@ -176,16 +175,14 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeT template <typename Float> struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const Float& src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const Float& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIFloat; result->v_float64 = static_cast<double>(src); } - static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) { - ConvertToAnyView(src, result); - } + static TVM_FFI_INLINE void MoveToAny(Float src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE std::optional<Float> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional<Float> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIFloat) { return std::make_optional<Float>(src->v_float64); } else if (src->type_index == TypeIndex::kTVMFFIInt) { @@ -198,7 +195,7 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == TypeIndex::kTVMFFIInt; } - static TVM_FFI_INLINE Float ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE Float CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIFloat) { return static_cast<Float>(src->v_float64); } else { @@ -212,16 +209,14 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> // void* template <> struct TypeTraits<void*> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(void* src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(void* src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIOpaquePtr; result->v_ptr = src; } - static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) { - ConvertToAnyView(src, result); - } + static TVM_FFI_INLINE void MoveToAny(void* src, TVMFFIAny* result) { CopyToAnyView(src, result); } - static TVM_FFI_INLINE std::optional<void*> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional<void*> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { return std::make_optional<void*>(src->v_ptr); } @@ -236,9 +231,7 @@ struct TypeTraits<void*> : public TypeTraitsBase { src->type_index == TypeIndex::kTVMFFINone; } - static TVM_FFI_INLINE void* ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { - return src->v_ptr; - } + static TVM_FFI_INLINE void* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return src->v_ptr; } static TVM_FFI_INLINE std::string TypeStr() { return "void*"; } }; @@ -250,7 +243,7 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef : public TypeTraitsBase { using ContainerType = typename TObjRef::ContainerType; - static TVM_FFI_INLINE void ConvertToAnyView(const TObjRef& src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const TObjRef& src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; @@ -268,14 +261,14 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef (src->type_index == kTVMFFINone && TObjRef::_type_is_nullable); } - static TVM_FFI_INLINE TObjRef ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE TObjRef CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if constexpr (TObjRef::_type_is_nullable) { if (src->type_index == kTVMFFINone) return TObjRef(nullptr); } return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); } - static TVM_FFI_INLINE std::optional<TObjRef> TryConvertFromAnyView(const TVMFFIAny* src) { + static TVM_FFI_INLINE std::optional<TObjRef> TryCopyFromAnyView(const TVMFFIAny* src) { if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { if (details::IsObjectInstance<ContainerType>(src->type_index)) { return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); @@ -293,7 +286,7 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef // Traits for ObjectPtr template <typename T> struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const ObjectPtr<T>& src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const ObjectPtr<T>& src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectPtr(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; @@ -310,12 +303,12 @@ struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase { details::IsObjectInstance<T>(src->type_index); } - static TVM_FFI_INLINE ObjectPtr<T> ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE ObjectPtr<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj); } - static TVM_FFI_INLINE std::optional<ObjectPtr<T>> TryConvertFromAnyView(const TVMFFIAny* src) { - if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); + static TVM_FFI_INLINE std::optional<ObjectPtr<T>> TryCopyFromAnyView(const TVMFFIAny* src) { + if (CheckAnyView(src)) return CopyFromAnyViewAfterCheck(src); return std::nullopt; } @@ -326,7 +319,7 @@ struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase { template <typename TObject> struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>>> : public TypeTraitsBase { - static TVM_FFI_INLINE void ConvertToAnyView(const TObject* src, TVMFFIAny* result) { + static TVM_FFI_INLINE void CopyToAnyView(const TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; @@ -345,12 +338,12 @@ struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TOb details::IsObjectInstance<TObject>(src->type_index); } - static TVM_FFI_INLINE const TObject* ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { return reinterpret_cast<const TObject*>(src->v_obj); } - static TVM_FFI_INLINE std::optional<const TObject*> TryConvertFromAnyView(const TVMFFIAny* src) { - if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); + static TVM_FFI_INLINE std::optional<const TObject*> TryCopyFromAnyView(const TVMFFIAny* src) { + if (CheckAnyView(src)) return CopyFromAnyViewAfterCheck(src); return std::nullopt; } diff --git a/ffi/scripts/run_tests.sh b/ffi/scripts/run_tests.sh index 0d4efe6bf2..8219afc90d 100755 --- a/ffi/scripts/run_tests.sh +++ b/ffi/scripts/run_tests.sh @@ -1,11 +1,11 @@ #!/bin/bash set -euxo pipefail -HEADER_ONLY=OFF -BUILD_TYPE=RelWithDebInfo +TVM_FFI_ALLOW_DYN_TYPE=ON +BUILD_TYPE=Release rm -rf build/CMakeFiles build/CMakeCache.txt -cmake -G Ninja -S . -B build -DTVM_FFI_ALLOW_DYN_TYPE=${HEADER_ONLY} -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ +cmake -G Ninja -S . -B build -DTVM_FFI_ALLOW_DYN_TYPE=${TVM_FFI_ALLOW_DYN_TYPE} -DTVM_FFI_BUILD_TESTS=ON -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ -DTVM_FFI_BUILD_REGISTRY=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build --parallel 16 --clean-first --config ${BUILD_TYPE} --target tvm_ffi_tests diff --git a/ffi/src/ffi/traceback.h b/ffi/src/ffi/traceback.h index e4ba2a589d..d2584081a7 100644 --- a/ffi/src/ffi/traceback.h +++ b/ffi/src/ffi/traceback.h @@ -130,4 +130,4 @@ struct TracebackStorage { } // namespace ffi } // namespace tvm -#endif +#endif // TVM_FFI_TRACEBACK_H_ diff --git a/ffi/tests/example/test_error.cc b/ffi/tests/example/test_error.cc index bf43f97ad5..0e8ce28b44 100644 --- a/ffi/tests/example/test_error.cc +++ b/ffi/tests/example/test_error.cc @@ -23,25 +23,24 @@ namespace { using namespace tvm::ffi; -void ThrowRuntimeError() { - TVM_FFI_THROW(RuntimeError) - << "test0"; -} +void ThrowRuntimeError() { TVM_FFI_THROW(RuntimeError) << "test0"; } TEST(Error, Traceback) { - EXPECT_THROW({ - try { - ThrowRuntimeError(); - } catch (const Error& error) { - EXPECT_EQ(error->message, "test0"); - EXPECT_EQ(error->kind, "RuntimeError"); - std::string what = error.what(); - EXPECT_NE(what.find("line"), std::string::npos); - EXPECT_NE(what.find("ThrowRuntimeError"), std::string::npos); - EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); - throw; - } - }, ::tvm::ffi::Error); + EXPECT_THROW( + { + try { + ThrowRuntimeError(); + } catch (const Error& error) { + EXPECT_EQ(error->message, "test0"); + EXPECT_EQ(error->kind, "RuntimeError"); + std::string what = error.what(); + EXPECT_NE(what.find("line"), std::string::npos); + EXPECT_NE(what.find("ThrowRuntimeError"), std::string::npos); + EXPECT_NE(what.find("RuntimeError: test0"), std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); } TEST(CheckError, Traceback) { diff --git a/ffi/tests/example/test_function.cc b/ffi/tests/example/test_function.cc index 46c69c3739..99e7819b23 100644 --- a/ffi/tests/example/test_function.cc +++ b/ffi/tests/example/test_function.cc @@ -18,8 +18,8 @@ */ #include <gtest/gtest.h> #include <tvm/ffi/any.h> -#include <tvm/ffi/memory.h> #include <tvm/ffi/function.h> +#include <tvm/ffi/memory.h> #include "./testing_object.h" @@ -29,95 +29,87 @@ using namespace tvm::ffi; using namespace tvm::ffi::testing; TEST(Func, FromPacked) { - Function fadd1 = Function::FromPacked( - [](int32_t num_args, const AnyView* args, Any* rv) { - EXPECT_EQ(num_args, 1); - int32_t a = args[0]; - *rv = a + 1; - } - ); + Function fadd1 = Function::FromPacked([](int32_t num_args, const AnyView* args, Any* rv) { + EXPECT_EQ(num_args, 1); + int32_t a = args[0]; + *rv = a + 1; + }); int b = fadd1(1); EXPECT_EQ(b, 2); - Function fadd2 = Function::FromPacked( - [](int32_t num_args, const AnyView* args, Any* rv) { - EXPECT_EQ(num_args, 1); - TInt a = args[0]; - EXPECT_EQ(a.use_count(), 2); - *rv = a->value + 1; - } - ); + Function fadd2 = Function::FromPacked([](int32_t num_args, const AnyView* args, Any* rv) { + EXPECT_EQ(num_args, 1); + TInt a = args[0]; + EXPECT_EQ(a.use_count(), 2); + *rv = a->value + 1; + }); EXPECT_EQ(fadd2(TInt(12)).operator int(), 13); } TEST(Func, FromUnpacked) { // try decution - Function fadd1 = Function::FromUnpacked( - [](const int32_t& a) -> int { return a + 1; } - ); + Function fadd1 = Function::FromUnpacked([](const int32_t& a) -> int { return a + 1; }); int b = fadd1(1); EXPECT_EQ(b, 2); - // convert that triggers error + // convert that triggers error EXPECT_THROW( { try { - fadd1(1.1); + fadd1(1.1); } catch (const Error& error) { EXPECT_EQ(error->kind, "TypeError"); - EXPECT_STREQ( - error->message.c_str(), - "Mismatched type on argument #0 when calling: `(0: int) -> int`. " - "Expected `int` but got `float`"); + EXPECT_STREQ(error->message.c_str(), + "Mismatched type on argument #0 when calling: `(0: int) -> int`. " + "Expected `int` but got `float`"); throw; } }, ::tvm::ffi::Error); - // convert that triggers error - EXPECT_THROW( + // convert that triggers error + EXPECT_THROW( { try { - fadd1(); + fadd1(); } catch (const Error& error) { EXPECT_EQ(error->kind, "TypeError"); - EXPECT_STREQ( - error->message.c_str(), - "Mismatched number of arguments when calling: `(0: int) -> int`. " - "Expected 1 but got 0 arguments"); + EXPECT_STREQ(error->message.c_str(), + "Mismatched number of arguments when calling: `(0: int) -> int`. " + "Expected 1 but got 0 arguments"); throw; } }, ::tvm::ffi::Error); // try decution - Function fpass_and_return = Function::FromUnpacked( - [](TInt x, int value, AnyView z) -> Function { - EXPECT_EQ(x.use_count(), 2); - EXPECT_EQ(x->value, value); - if (auto opt = z.TryAs<int>()) { - EXPECT_EQ(value, *opt); - } - return Function::FromUnpacked([value](int x) -> int { return x + value; }); - }, - "fpass_and_return"); - TInt a(11); - Function fret = fpass_and_return(std::move(a), 11, 11); - EXPECT_EQ(fret(12).operator int(), 23); + Function fpass_and_return = Function::FromUnpacked( + [](TInt x, int value, AnyView z) -> Function { + EXPECT_EQ(x.use_count(), 2); + EXPECT_EQ(x->value, value); + if (auto opt = z.TryAs<int>()) { + EXPECT_EQ(value, *opt); + } + return Function::FromUnpacked([value](int x) -> int { return x + value; }); + }, + "fpass_and_return"); + TInt a(11); + Function fret = fpass_and_return(std::move(a), 11, 11); + EXPECT_EQ(fret(12).operator int(), 23); - EXPECT_THROW( - { - try { - fpass_and_return(); - } catch (const Error& error) { - EXPECT_EQ(error->kind, "TypeError"); - EXPECT_STREQ(error->message.c_str(), - "Mismatched number of arguments when calling: " - "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> object.Function`. " - "Expected 3 but got 0 arguments"); - throw; - } - }, - ::tvm::ffi::Error); + EXPECT_THROW( + { + try { + fpass_and_return(); + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + EXPECT_STREQ(error->message.c_str(), + "Mismatched number of arguments when calling: " + "`fpass_and_return(0: test.Int, 1: int, 2: AnyView) -> object.Function`. " + "Expected 3 but got 0 arguments"); + throw; + } + }, + ::tvm::ffi::Error); } } // namespace diff --git a/ffi/tests/example/testing_object.h b/ffi/tests/example/testing_object.h index 00cc58b7b7..e660b2751e 100644 --- a/ffi/tests/example/testing_object.h +++ b/ffi/tests/example/testing_object.h @@ -51,9 +51,7 @@ class TIntObj : public TNumberObj { class TInt : public TNumber { public: - explicit TInt(int64_t value) { - data_ = make_object<TIntObj>(value); - } + explicit TInt(int64_t value) { data_ = make_object<TIntObj>(value); } TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TInt, TNumber, TIntObj); }; @@ -70,9 +68,7 @@ class TFloatObj : public TNumberObj { class TFloat : public TNumber { public: - explicit TFloat(double value) { - data_ = make_object<TFloatObj>(value); - } + explicit TFloat(double value) { data_ = make_object<TFloatObj>(value); } TVM_FFI_DEFINE_NULLABLE_OBJECT_REF_METHODS(TFloat, TNumber, TFloatObj); }; diff --git a/tests/lint/cpplint.sh b/tests/lint/cpplint.sh index 39b86937ad..b2b1bbdd99 100755 --- a/tests/lint/cpplint.sh +++ b/tests/lint/cpplint.sh @@ -19,6 +19,7 @@ set -e echo "Running 2 cpplints..." +python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp ffi/include ffi/src python3 3rdparty/dmlc-core/scripts/lint.py --quiet tvm cpp \ include src \ examples/extension/src examples/graph_executor/src \
