This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 38ffb877689724a4a14c49add649a6fff2646160 Author: tqchen <[email protected]> AuthorDate: Sat Sep 7 19:05:46 2024 -0400 [FFI] Introduce array support --- ffi/include/tvm/ffi/any.h | 46 +- ffi/include/tvm/ffi/c_api.h | 2 +- ffi/include/tvm/ffi/container/array.h | 1010 ++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/container/base.h | 268 ++++++++ ffi/include/tvm/ffi/container/optional.h | 0 ffi/include/tvm/ffi/memory.h | 9 +- ffi/include/tvm/ffi/object.h | 2 +- ffi/include/tvm/ffi/type_traits.h | 230 ++++--- ffi/tests/example/test_array.cc | 231 +++++++ 9 files changed, 1703 insertions(+), 95 deletions(-) diff --git a/ffi/include/tvm/ffi/any.h b/ffi/include/tvm/ffi/any.h index 2a2d223c57..9912c45f12 100644 --- a/ffi/include/tvm/ffi/any.h +++ b/ffi/include/tvm/ffi/any.h @@ -31,6 +31,12 @@ namespace ffi { class Any; +namespace details { +// Helper to perform +// unsafe operations related to object +struct AnyUnsafe; +} // namespace details + /*! * \brief AnyView allows us to take un-managed reference view of any value. 
*/ @@ -94,8 +100,9 @@ class AnyView { if (opt.has_value()) { return std::move(*opt); } - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) - << "` to `" << TypeTraits<T>::TypeStr() << "`"; + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits<T>::TypeStr() << "`"; TVM_FFI_UNREACHABLE(); } // The following functions are only used for testing purposes @@ -212,8 +219,9 @@ class Any { if (opt.has_value()) { return std::move(*opt); } - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" << TypeIndex2TypeKey(data_.type_index) - << "` to `" << TypeTraits<T>::TypeStr() << "`"; + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits<T>::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits<T>::TypeStr() << "`"; TVM_FFI_UNREACHABLE(); } @@ -226,12 +234,42 @@ class Any { *result = data_; data_.type_index = TypeIndex::kTVMFFINone; } + + friend class details::AnyUnsafe; }; // layout assert to ensure we can freely cast between the two types static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); static_assert(sizeof(Any) == sizeof(TVMFFIAny)); +namespace details { +// Extra unsafe method to help any manipulation +struct AnyUnsafe : public ObjectUnsafe { + /*! + * \brief Internal helper function downcast a any that already passes check. + * \note Only used for internal dev purposes. + * \tparam T The target reference type. + * \return The casted result. 
+ */ + template <typename T> + static TVM_FFI_INLINE T ConvertAfterCheck(const Any& ref) { + if constexpr (!std::is_same_v<T, Any>) { + return TypeTraits<T>::ConvertFromAnyViewAfterCheck(&(ref.data_)); + } else { + return ref; + } + } + template <typename T> + static TVM_FFI_INLINE bool CheckAny(const Any& ref) { + return TypeTraits<T>::CheckAnyView(&(ref.data_)); + } + + template <typename T> + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const Any& ref) { + return TypeTraits<T>::GetMismatchTypeInfo(&(ref.data_)); + } +}; +} // namespace details } // namespace ffi } // namespace tvm #endif // TVM_FFI_ANY_H_ diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 65341feb83..659ee6cd83 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -77,7 +77,7 @@ typedef enum { // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin) kTVMFFIStaticObjectBegin = 64, kTVMFFIObject = 64, - kTVMFFIList = 65, + kTVMFFIArray = 65, kTVMFFIDict = 66, kTVMFFIError = 67, kTVMFFIFunc = 68, diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h new file mode 100644 index 0000000000..2ce946a120 --- /dev/null +++ b/ffi/include/tvm/ffi/container/array.h @@ -0,0 +1,1010 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/array.h + * \brief Array type. + * + * tvm::ffi::Array<Any> is an erased type that contains list of content + */ +#ifndef TVM_FFI_CONTAINER_ARRAY_H_ +#define TVM_FFI_CONTAINER_ARRAY_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/base.h> +#include <tvm/ffi/memory.h> +#include <tvm/ffi/object.h> + +#include <algorithm> +#include <memory> +#include <type_traits> +#include <utility> +#include <vector> + +namespace tvm { +namespace ffi { + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, Any> { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const Any at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const Any* begin() const { return static_cast<Any*>(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const Any* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, Any item) { this->operator[](i) = std::move(item); } + + /*! 
+ * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(ValueError) << "not enough capacity"; + } + ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); + Any* write = p->MutableBegin(); + Any* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) Any(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(RuntimeError) << "not enough capacity"; + } + ObjectPtr<ArrayNode> p = ArrayNode::Empty(cap); + Any* write = p->MutableBegin(); + Any* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) Any(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. 
Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> CreateRepeated(int64_t n, const Any& val) { + ObjectPtr<ArrayNode> p = ArrayNode::Empty(n); + Any* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) Any(val); + } + return p; + } + + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; + static constexpr const char* _type_key = "object.Array"; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + Any* MutableBegin() const { return static_cast<Any*>(InplaceArrayBase::AddressOf(0)); } + + /*! \return end mutable iterator */ + Any* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr<ArrayNode> Empty(int64_t n = kInitSize) { + TVM_FFI_ICHECK_GE(n, 0); + ObjectPtr<ArrayNode> p = make_inplace_array_object<ArrayNode, Any>(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template <typename IterType> + ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { + Any* itr = MutableBegin() + idx; + for (; first != last; ++first) { + Any ref = *first; + new (itr++) Any(std::move(ref)); + } + return this; + } + + /*! 
+ * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* from = MutableBegin() + src_begin; + Any* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* from = MutableBegin() + src_end; + Any* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* EnlargeBy(int64_t delta, const Any& val = Any()) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) Any(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* ShrinkBy(int64_t delta) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->Any::~Any(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static constexpr int64_t kInitSize = 4; + + /*! 
\brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase<ArrayNode, Any>; + + // Reference class + template <typename, typename> + friend class Array; + + // To specialize make_object<ArrayNode> + friend ObjectPtr<ArrayNode> make_object<>(); +}; + +/*! \brief Helper struct for type-checking + * + * is_valid_iterator<T,IterType>::value will be true if IterType can + * be dereferenced into a type that can be stored in an Array<T>, and + * false otherwise. + */ +template <typename T, typename IterType> +struct is_valid_iterator + : std::bool_constant< + std::is_same_v< + T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>> || + std::is_base_of_v< + T, std::remove_cv_t<std::remove_reference_t<decltype(*std::declval<IterType>())>>>> { +}; + +template <typename T, typename IterType> +struct is_valid_iterator<Optional<T>, IterType> : is_valid_iterator<T, IterType> {}; + +template <typename T, typename IterType> +inline constexpr bool is_valid_iterator_v = is_valid_iterator<T, IterType>::value; + +/*! + * \brief Array, container representing a contiguous sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. + * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. 
+ * \tparam T The content Value type, must be compatible with tvm::ffi::Any + */ +template <typename T, + typename = typename std::enable_if_t<std::is_same_v<T, Any> || TypeTraits<T>::enabled>> +class Array : public ObjectRef { + public: + using value_type = T; + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array<T>&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array<T>& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + Array(IterType first, IterType last) { + static_assert(is_valid_iterator_v<T, IterType>, + "IterType cannot be inserted into a tvm::Array<T>"); + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list<T> init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector<T>& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. 
+ */ + Array<T>& operator=(Array<T>&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array<T>& operator=(const Array<T>& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const Any& n) { return details::AnyUnsafe::ConvertAfterCheck<T>(n); } + }; + + using iterator = IterAdapter<ValueConverter, const Any*>; + using reverse_iterator = ReverseIterAdapter<ValueConverter, const Any*>; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayNode::end() is never nullptr + return reverse_iterator(GetArrayNode()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayNode::begin() is never nullptr + return reverse_iterator(GetArrayNode()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + TVM_FFI_THROW(IndexError) << "cannot index a null array"; + } + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + return details::AnyUnsafe::ConvertAfterCheck<T>(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 
0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayNode* p = GetArrayNode(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::ConvertAfterCheck<T>(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayNode* p = GetArrayNode(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::ConvertAfterCheck<T>(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) Any(val); + } + + /*! 
+ * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template <typename IterType> + void insert(iterator position, IterType first, IterType last) { + static_assert(is_valid_iterator_v<T, IterType>, + "IterType cannot be inserted into a tvm::Array<T>"); + + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; + } + int64_t size = GetArrayNode()->size_; + if (size == 0) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; + } + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + if (st < 0 || st >= size) { + TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " + << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! 
+ * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + if (st >= ed) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; + } + if (st < 0 || st > size || ed < 0 || ed > size) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; + } + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + template <typename... Args> + static size_t CalcCapacityImpl() { + return 0; + } + + template <typename... Args> + static size_t CalcCapacityImpl(Array<T> value, Args... args) { + return value.size() + CalcCapacityImpl(args...); + } + + template <typename... Args> + static size_t CalcCapacityImpl(T value, Args... 
args) { + return 1 + CalcCapacityImpl(args...); + } + + template <typename... Args> + static void AgregateImpl(Array<T>& dest) {} // NOLINT(*) + + template <typename... Args> + static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) { // NOLINT(*) + dest.insert(dest.end(), value.begin(), value.end()); + AgregateImpl(dest, args...); + } + + template <typename... Args> + static void AgregateImpl(Array<T>& dest, T value, Args... args) { // NOLINT(*) + dest.push_back(value); + AgregateImpl(dest, args...); + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayNode* p = this->CopyOnWrite(); + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast<ArrayNode*>(data_.get()); } + + /*! + * \brief Helper function to apply a map function onto the array. + * + * \param fmap The transformation function T -> U. + * + * \tparam F The type of the mutation function. + * + * \tparam U The type of the returned array, inferred from the + * return type of F. If overridden by the user, must be something + * that is convertible from the return type of F. + * + * \note This function performs copy on write optimization. If + * `fmap` returns an object of type `T`, and all elements of the + * array are mapped to themselves, then the returned array will be + * the same as the original, and reference counts of the elements in + * the array will not be incremented. + * + * \return The transformed array. + */ + template <typename F, typename U = std::invoke_result_t<F, T>> + Array<U> Map(F fmap) const { + return Array<U>(MapHelper(data_, fmap)); + } + + /*! + * \brief Helper function to apply fmutate to mutate an array. 
+ * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template <typename F, typename = std::enable_if_t<std::is_same_v<T, std::invoke_result_t<F, T>>>> + void MutateByApply(F fmutate) { + data_ = MapHelper(std::move(data_), fmutate); + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + if (cap < 0) { + TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; + } + ArrayNode* p = GetArrayNode(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayNode::Empty(cap); + p = GetArrayNode(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + Any* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) Any(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayNode* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayNode::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast<ArrayNode*>(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + /*! + * \brief Agregate arguments into a single Array<T> + * \param args sequence of T or Array<T> elements + * \return Agregated Array<T> + */ + template <typename... 
Args> + static Array<T> Agregate(Args... args) { + Array<T> result; + result.reserve(CalcCapacityImpl(args...)); + AgregateImpl(result, args...); + return result; + } + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayNode pointer to the unique copy + */ + ArrayNode* CopyOnWrite(int64_t reserve_extra) { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayNode::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayNode::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayNode to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayNode* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayNode::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); + } else { + data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); + } + return static_cast<ArrayNode*>(data_.get()); + } + + /*! \brief Helper method for mutate/map + * + * A helper function used internally by both `Array::Map` and + * `Array::MutateInPlace`. Given an array of data, apply the + * mapping function to each element, returning the collected array. + * Applies both mutate-in-place and copy-on-write optimizations, if + * possible. + * + * \param data A pointer to the ArrayNode containing input data. + * Passed by value to allow for mutate-in-place optimizations. + * + * \param fmap The mapping function + * + * \tparam F The type of the mutation function. + * + * \tparam U The output type of the mutation function. 
Inferred + * from the callable type given. Must inherit from ObjectRef. + * + * \return The mapped array. Depending on whether mutate-in-place + * or copy-on-write optimizations were applicable, may be the same + * underlying array as the `data` parameter. + */ + template <typename F, typename U = std::invoke_result_t<F, T>> + static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) { + if (data == nullptr) { + return nullptr; + } + + TVM_FFI_ICHECK(data->IsInstance<ArrayNode>()); + + constexpr bool is_same_output_type = std::is_same_v<T, U>; + + if constexpr (is_same_output_type) { + if (data.unique()) { + // Mutate-in-place path. Only allowed if the output type U is + // the same as type T, we have a mutable this*, and there are + // no other shared copies of the array. + auto arr = static_cast<ArrayNode*>(data.get()); + for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { + T mapped = fmap(details::AnyUnsafe::ConvertAfterCheck<T>(std::move(*it))); + *it = std::move(mapped); + } + return data; + } + } + + constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>; + + ObjectPtr<ArrayNode> output = nullptr; + auto arr = static_cast<ArrayNode*>(data.get()); + + auto it = arr->begin(); + if constexpr (compatible_types) { + // Copy-on-write path, if the output Array<U> might be + // represented by the same underlying array as the existing + // Array<T>. Typically, this is for functions that map `T` to + // `T`, but can also apply to functions that map `T` to + // `Optional<T>`, or that map `T` to a subclass or superclass of + // `T`. + bool all_identical = true; + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::ConvertAfterCheck<T>(*it)); + if (!mapped.same_as(*it)) { + // At least one mapped element is different than the + // original. 
Therefore, prepare the output array, + // consisting of any previous elements that had mapped to + // themselves (if any), and the element that didn't map to + // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. Since the default `Any()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. + all_identical = false; + output = ArrayNode::CreateRepeated(arr->size(), Any()); + output->InitRange(0, arr->begin(), it); + output->SetItem(it - arr->begin(), std::move(mapped)); + it++; + break; + } + } + if (all_identical) { + return data; + } + } else { + // Path for incompatible types. The constexpr check for + // compatible types isn't strictly necessary, as the first + // mapped.same_as(*it) would return false, but we might as well + // avoid it altogether. + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `Any()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. + output = ArrayNode::CreateRepeated(arr->size(), Any()); + } + + // Normal path for incompatible types, or post-copy path for + // copy-on-write instances. + // + // If the types are incompatible, then at this point `output` is + // empty, and `it` points to the first element of the input. + // + // If the types were compatible, then at this point `output` + // contains zero or more elements that mapped to themselves + // followed by the first element that does not map to itself, and + // `it` points to the element just after the first element that + // does not map to itself. Because at least one element has been + // changed, we no longer have the opportunity to avoid a copy, so + // we don't need to check the result. 
+ // + // In both cases, `it` points to the next element to be processed, + // so we can either start or resume the iteration from that point, + // with no further checks on the result. + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::ConvertAfterCheck<T>(*it)); + output->SetItem(it - arr->begin(), std::move(mapped)); + } + + return output; + } +}; + +/*! + * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template <typename T, + typename = typename std::enable_if_t<std::is_same_v<T, Any> || TypeTraits<T>::enabled>> +inline Array<T> Concat(Array<T> lhs, const Array<T>& rhs) { + for (const auto& x : rhs) { + lhs.push_back(x); + } + return lhs; // by-value parameter: implicitly moved on return +} + +// Specialize make_object<ArrayNode> so that it always goes through ArrayNode::Empty(). +template <> +inline ObjectPtr<ArrayNode> make_object() { + return ArrayNode::Empty(); +} + +// Traits for Array (a null Array is stored as None, and None converts back to a null Array) +template <typename T> +inline constexpr bool use_default_type_traits_v<Array<T>> = false; + +template <typename T> +struct TypeTraits<Array<T>> : public TypeTraitsBase { + static TVM_FFI_INLINE void ConvertToAnyView(const Array<T>& src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src); + result->type_index = obj_ptr != nullptr ? obj_ptr->type_index : TypeIndex::kTVMFFINone; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE void MoveToAny(Array<T> src, TVMFFIAny* result) { + TVMFFIObject* obj_ptr = details::ObjectUnsafe::MoveTVMFFIObjectPtrFromObjectRef(&src); + result->type_index = obj_ptr != nullptr ? obj_ptr->type_index : TypeIndex::kTVMFFINone; + result->v_obj = obj_ptr; + } + + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v<T, Any>) { + const ArrayNode* n = reinterpret_cast<const ArrayNode*>(src->v_obj); + for (size_t
i = 0; i < n->size(); i++) { + const Any& p = (*n)[i]; + if (!details::AnyUnsafe::CheckAny<T>(p)) { + return "Array[index " + std::to_string(i) + ": " + + details::AnyUnsafe::GetMismatchTypeInfo<T>(p) + "]"; + } + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + // for now allow null array, TODO: revisit this + if (src->type_index == TypeIndex::kTVMFFINone) return true; + if (src->type_index != TypeIndex::kTVMFFIArray) return false; + if constexpr (std::is_same_v<T, Any>) { + return true; + } else { + const ArrayNode* n = reinterpret_cast<const ArrayNode*>(src->v_obj); + for (size_t i = 0; i < n->size(); i++) { + const Any& p = (*n)[i]; + if (!details::AnyUnsafe::CheckAny<T>(p)) return false; + } + return true; + } + } + + static TVM_FFI_INLINE Array<T> ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFINone) return Array<T>(nullptr); + return Array<T>(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); + } + + static TVM_FFI_INLINE std::optional<Array<T>> TryConvertFromAnyView(const TVMFFIAny* src) { + if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "Array<" + TypeTraits<T>::TypeStr() + ">"; } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/ffi/include/tvm/ffi/container/base.h b/ffi/include/tvm/ffi/container/base.h new file mode 100644 index 0000000000..bfeeacd59b --- /dev/null +++ b/ffi/include/tvm/ffi/container/base.h @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/base.h + * \brief Base utilities for common POD(plain old data) container types. + */ +#ifndef TVM_FFI_CONTAINER_BASE_H_ +#define TVM_FFI_CONTAINER_BASE_H_ + +#include <tvm/ffi/memory.h> +#include <tvm/ffi/object.h> + +#include <algorithm> +#include <initializer_list> +#include <utility> + +namespace tvm { +namespace ffi { +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType.
+ * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayNode : public InplaceArrayBase<ArrayNode, Elem> { + * public: + * // Wrap EmplaceInit to initialize the elements + * template <typename Iterator> + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * }; + * + * void test_function() { + * std::vector<Elem> fields; + * auto ptr = make_inplace_array_object<ArrayNode, Elem>(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template <typename ArrayType, typename ElemType> +class InplaceArrayBase { + public: + /*! + * \brief Access element at index; throws IndexError when idx >= GetSize(). + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType& operator[](size_t idx) const { + size_t size = Self()->GetSize(); + if (idx >= size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast<ElemType*>(AddressOf(idx))); + } + + /*! + * \brief Access element at index; throws IndexError when idx >= GetSize(). + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) { + size_t size = Self()->GetSize(); + if (idx >= size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast<ElemType*>(AddressOf(idx))); + } + + /*!
+ * \brief Destroy the first Self()->GetSize() in-place elements (skipped for trivial ElemType). + */ + ~InplaceArrayBase() { + if constexpr (!(std::is_standard_layout<ElemType>::value && std::is_trivial<ElemType>::value)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast<ElemType*>(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + protected: + /*! + * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template <typename... Args> + void EmplaceInit(size_t idx, Args&&... args) { + void* field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward<Args>(args)...); + } + + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + constexpr size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast<char*>(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! + * \brief Iterator adapter over TIter whose dereference returns Converter::convert(*iter). + * \tparam Converter a struct providing ResultType and a static convert function + * \tparam TIter the content iterator type.
+ */ +template <typename Converter, typename TIter> +class IterAdapter { + public: + using difference_type = typename std::iterator_traits<TIter>::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; + using iterator_category = typename std::iterator_traits<TIter>::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } + + template <typename T = IterAdapter> // SFINAE: only enabled when TIter is random-access + typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief Reverse iterator adapter: ++ moves TIter backwards, -- moves it forwards. + * \tparam Converter a struct providing ResultType and a static convert function + * \tparam TIter the content iterator type.
+ */ +template <typename Converter, typename TIter> +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits<TIter>::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits<TIter>::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { // advancing moves the underlying TIter backwards + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { // stepping back moves the underlying TIter forwards + ++iter_; + return *this; + } + ReverseIterAdapter operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { // offset applied backwards + return ReverseIterAdapter(iter_ - offset); + } + + template <typename T = ReverseIterAdapter> + typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; // operands swapped relative to IterAdapter + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_BASE_H_ diff --git a/ffi/include/tvm/ffi/container/optional.h b/ffi/include/tvm/ffi/container/optional.h new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index 03cd542d91..d0e5995fa5 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -95,9 +95,10 @@ class ObjAllocatorBase { "make_inplace_array can only be used to create Object");
ArrayType* ptr = Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...); - ptr->type_index_ = ArrayType::RuntimeTypeIndex(); - ptr->deleter_ = Handler::Deleter(); - return ObjectPtr<ArrayType>(ptr); + TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectUnsafe::ObjectPtrFromUnowned<ArrayType>(ptr); } }; @@ -181,7 +182,7 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(Object* objptr) { + static void Deleter_(void* objptr) { // NOTE: this is important to cast back to ArrayType* // because objptr and tptr may not be the same // depending on how sub-class allocates the space. diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index b7a0d1a7c5..13de74fa07 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -319,7 +319,7 @@ class ObjectPtr { * \param data The data pointer */ explicit ObjectPtr(Object* data) : data_(data) { - if (data != nullptr) { + if (data_ != nullptr) { data_->IncRef(); } } diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index e1d453fdba..be46639642 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -34,6 +34,39 @@ namespace tvm { namespace ffi { +/*!
+ * \brief Get type key from type index + * \param type_index The input type index + * \return the type key + */ +inline std::string TypeIndex2TypeKey(int32_t type_index) { + switch (type_index) { + case TypeIndex::kTVMFFINone: + return "None"; + case TypeIndex::kTVMFFIInt: + return "int"; + case TypeIndex::kTVMFFIFloat: + return "float"; + case TypeIndex::kTVMFFIOpaquePtr: + return "void*"; + case TypeIndex::kTVMFFIDataType: + return "DataType"; + case TypeIndex::kTVMFFIDevice: + return "Device"; + case TypeIndex::kTVMFFIRawStr: + return "const char*"; + default: { + TVM_FFI_ICHECK_GE(type_index, TypeIndex::kTVMFFIStaticObjectBegin) + << "Unknown type_index=" << type_index; +#if TVM_FFI_ALLOW_DYN_TYPE + const TypeInfo* type_info = details::ObjectGetTypeInfo(type_index); + return type_info->type_key; +#else + return "object.Object"; +#endif + } + } +} /*! * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. * @@ -59,17 +92,36 @@ struct TypeTraits { template <typename T> using TypeTraitsNoCR = TypeTraits<std::remove_const_t<std::remove_reference_t<T>>>; -// None -template <> -struct TypeTraits<std::nullptr_t> { +template <typename T> // specialize to false to opt a type out of the default ObjectRef traits +inline constexpr bool use_default_type_traits_v = true; + +struct TypeTraitsBase { static constexpr bool enabled = true; + // get mismatched type when result mismatches the trait.
+ // this function is called after TryConvertFromAnyView fails + // to get more detailed type information at runtime, + // especially when the error involves nested container types + static TVM_FFI_INLINE std::string GetMismatchTypeInfo(const TVMFFIAny* source) { + return TypeIndex2TypeKey(source->type_index); + } +}; + +// None +template <> +struct TypeTraits<std::nullptr_t> : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(const std::nullptr_t&, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; + // invariant: the pointer field is also zeroed (nullptr), + // which simplifies recovering a nullable object from the any + result->v_int64 = 0; } static TVM_FFI_INLINE void MoveToAny(std::nullptr_t, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFINone; + // invariant: the pointer field is also zeroed (nullptr), + // which simplifies recovering a nullable object from the any + result->v_int64 = 0; } static TVM_FFI_INLINE std::optional<std::nullptr_t> TryConvertFromAnyView(const TVMFFIAny* src) { @@ -79,14 +131,20 @@ struct TypeTraits<std::nullptr_t> { return std::nullopt; } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFINone; + } + + static TVM_FFI_INLINE std::nullptr_t ConvertFromAnyViewAfterCheck(const TVMFFIAny*) { + return nullptr; + } + static TVM_FFI_INLINE std::string TypeStr() { return "None"; } }; // Integer POD values template <typename Int> -struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { - static constexpr bool enabled = true; - +struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(const Int& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIInt; result->v_int64 = static_cast<int64_t>(src); @@ -103,14 +161,21 @@ struct TypeTraits<Int, std::enable_if_t<std::is_integral_v<Int>>> { return std::nullopt; } + static TVM_FFI_INLINE bool
CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIInt; + } + + static TVM_FFI_INLINE Int ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast<Int>(src->v_int64); + } + + static TVM_FFI_INLINE std::string TypeStr() { return "int"; } }; // Float POD values template <typename Float> -struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> { - static constexpr bool enabled = true; - +struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> + : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(const Float& src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIFloat; result->v_float64 = static_cast<double>(src); @@ -129,14 +194,24 @@ struct TypeTraits<Float, std::enable_if_t<std::is_floating_point_v<Float>>> { return std::nullopt; } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFIFloat || src->type_index == TypeIndex::kTVMFFIInt; + } + + static TVM_FFI_INLINE Float ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIFloat) { + return static_cast<Float>(src->v_float64); + } else { + return static_cast<Float>(src->v_int64); + } + } + + static TVM_FFI_INLINE std::string TypeStr() { return "float"; } }; // void* template <> -struct TypeTraits<void*> { - static constexpr bool enabled = true; - +struct TypeTraits<void*> : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(void* src, TVMFFIAny* result) { result->type_index = TypeIndex::kTVMFFIOpaquePtr; result->v_ptr = src; @@ -150,19 +225,31 @@ struct TypeTraits<void*> { if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { return std::make_optional<void*>(src->v_ptr); } + if (src->type_index == TypeIndex::kTVMFFINone) { + return std::make_optional<void*>(nullptr); + } return std::nullopt; } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index ==
TypeIndex::kTVMFFIOpaquePtr || + src->type_index == TypeIndex::kTVMFFINone; + } + + static TVM_FFI_INLINE void* ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + return src->type_index == TypeIndex::kTVMFFINone ? nullptr : src->v_ptr; + } + + static TVM_FFI_INLINE std::string TypeStr() { return "void*"; } }; -// Traits for object +// Traits for ObjectRef template <typename TObjRef> -struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef>>> { +struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef> && + use_default_type_traits_v<TObjRef>>> + : public TypeTraitsBase { using ContainerType = typename TObjRef::ContainerType; - static constexpr bool enabled = true; - static TVM_FFI_INLINE void ConvertToAnyView(const TObjRef& src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectRef(src); result->type_index = obj_ptr->type_index; @@ -175,19 +262,27 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef result->v_obj = obj_ptr; } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && + details::IsObjectInstance<ContainerType>(src->type_index)) || + (src->type_index == kTVMFFINone && TObjRef::_type_is_nullable); + } + + static TVM_FFI_INLINE TObjRef ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == kTVMFFINone) return TObjRef(nullptr); + } + return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); + } + + static TVM_FFI_INLINE std::optional<TObjRef> TryConvertFromAnyView(const TVMFFIAny* src) { if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { -#if TVM_FFI_ALLOW_DYN_TYPE if (details::IsObjectInstance<ContainerType>(src->type_index)) { return TObjRef(details::ObjectUnsafe::ObjectPtrFromUnowned<Object>(src->v_obj)); } -#else - TVM_FFI_THROW(RuntimeError) - << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to
be on". -#endif - } else if (src->type_index == kTVMFFINone) { - if (!TObjRef::_type_is_nullable) return std::nullopt; - return TObjRef(ObjectPtr<Object>()); + } + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == kTVMFFINone) return TObjRef(nullptr); } return std::nullopt; } @@ -195,11 +290,9 @@ struct TypeTraits<TObjRef, std::enable_if_t<std::is_base_of_v<ObjectRef, TObjRef static TVM_FFI_INLINE std::string TypeStr() { return ContainerType::_type_key; } }; -// Traits for object +// Traits for ObjectPtr template <typename T> -struct TypeTraits<ObjectPtr<T>> { - static constexpr bool enabled = true; - +struct TypeTraits<ObjectPtr<T>> : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(const ObjectPtr<T>& src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetTVMFFIObjectPtrFromObjectPtr(src); result->type_index = obj_ptr->type_index; @@ -212,28 +305,27 @@ struct TypeTraits<ObjectPtr<T>> { result->v_obj = obj_ptr; } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && + details::IsObjectInstance<T>(src->type_index); + } + + static TVM_FFI_INLINE ObjectPtr<T> ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj); + } + + static TVM_FFI_INLINE std::optional<ObjectPtr<T>> TryConvertFromAnyView(const TVMFFIAny* src) { - if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { -#if TVM_FFI_ALLOW_DYN_TYPE - if (details::IsObjectInstance<T>(src->type_index)) { - return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(src->v_obj); - } -#else - TVM_FFI_THROW(RuntimeError) - << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on".
-#endif - } + if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); return std::nullopt; } static TVM_FFI_INLINE std::string TypeStr() { return T::_type_key; } }; -// Traits for object +// Traits for non-owning raw pointer to an object template <typename TObject> -struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>>> { - static constexpr bool enabled = true; - +struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TObject>>> + : public TypeTraitsBase { static TVM_FFI_INLINE void ConvertToAnyView(const TObject* src, TVMFFIAny* result) { TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; @@ -244,59 +336,27 @@ struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TOb TVMFFIObject* obj_ptr = details::ObjectUnsafe::GetHeader(src); result->type_index = obj_ptr->type_index; result->v_obj = obj_ptr; + // increase the ref count: the raw pointer does not own the object, but the Any does details::ObjectUnsafe::IncRefObjectInAny(result); } + static TVM_FFI_INLINE bool CheckAnyView(const TVMFFIAny* src) { + return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && + details::IsObjectInstance<TObject>(src->type_index); + } + + static TVM_FFI_INLINE const TObject* ConvertFromAnyViewAfterCheck(const TVMFFIAny* src) { + return reinterpret_cast<const TObject*>(src->v_obj); + } + + static TVM_FFI_INLINE std::optional<const TObject*> TryConvertFromAnyView(const TVMFFIAny* src) { - if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { -#if TVM_FFI_ALLOW_DYN_TYPE - if (details::IsObjectInstance<TObject>(src->type_index)) { - return reinterpret_cast<const TObject*>(src->v_obj); - } -#else - TVM_FFI_THROW(RuntimeError) - << "Converting to object requires `TVM_FFI_ALLOW_DYN_TYPE` to be on".
-#endif - } + if (CheckAnyView(src)) return ConvertFromAnyViewAfterCheck(src); return std::nullopt; } static TVM_FFI_INLINE std::string TypeStr() { return TObject::_type_key; } }; -/*! - * \brief Get type key from type index - * \param type_index The input type index - * \return the type key - */ -inline std::string TypeIndex2TypeKey(int32_t type_index) { - switch (type_index) { - case TypeIndex::kTVMFFINone: - return "None"; - case TypeIndex::kTVMFFIInt: - return "int"; - case TypeIndex::kTVMFFIFloat: - return "float"; - case TypeIndex::kTVMFFIOpaquePtr: - return "void*"; - case TypeIndex::kTVMFFIDataType: - return "DataType"; - case TypeIndex::kTVMFFIDevice: - return "Device"; - case TypeIndex::kTVMFFIRawStr: - return "const char*"; - default: { - TVM_FFI_ICHECK_GE(type_index, TypeIndex::kTVMFFIStaticObjectBegin) - << "Uknown type_index=" << type_index; -#if TVM_FFI_ALLOW_DYN_TYPE - const TypeInfo* type_info = details::ObjectGetTypeInfo(type_index); - return type_info->type_key; -#else - return "object.Object"; -#endif - } - } -} } // namespace ffi } // namespace tvm #endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/ffi/tests/example/test_array.cc b/ffi/tests/example/test_array.cc new file mode 100644 index 0000000000..c5ba557425 --- /dev/null +++ b/ffi/tests/example/test_array.cc @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <gtest/gtest.h> +#include <tvm/ffi/container/array.h> + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +TEST(Array, Basic) { + Array<TInt> arr = {TInt(11), TInt(12)}; + TInt v1 = arr[0]; + EXPECT_EQ(v1->value, 11); + EXPECT_EQ(v1.use_count(), 2); // one reference held by arr, one by v1 + EXPECT_EQ(arr[1]->value, 12); +} + +TEST(Array, COWSet) { + Array<TInt> arr = {TInt(11), TInt(12)}; + Array<TInt> arr2 = arr; + EXPECT_EQ(arr.use_count(), 2); + arr.Set(1, TInt(13)); // arr is shared with arr2, so Set must detach arr first + EXPECT_EQ(arr.use_count(), 1); + EXPECT_EQ(arr[1]->value, 13); + EXPECT_EQ(arr2[1]->value, 12); +} + +TEST(Array, AnyConvertCheck) { + Array<Any> arr = {11.1, 1}; + EXPECT_EQ(arr[1].operator int(), 1); + + AnyView view0 = arr; + Array<double> arr1 = view0; + EXPECT_EQ(arr1[0], 11.1); + EXPECT_EQ(arr1[1], 1.0); + + Any any1 = arr; + + EXPECT_THROW( + { + try { + [[maybe_unused]] Array<int> arr2 = any1; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); + EXPECT_NE(what.find("Cannot convert from type `Array[index 0: float]` to `Array<int>`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); + + Array<Array<TNumber>> arr_nested = {{}, {TInt(1), TFloat(2)}}; + any1 = arr_nested; + Array<Array<TNumber>> arr1_nested = any1; + EXPECT_EQ(arr1_nested.use_count(), 3); // shared by arr_nested, any1 and arr1_nested + + EXPECT_THROW( + { + try { + [[maybe_unused]] Array<Array<int>> arr2 = any1; + } catch (const Error& error) { + EXPECT_EQ(error->kind, "TypeError"); + std::string what = error.what(); +
EXPECT_NE(what.find("`Array[index 1: Array[index 0: test.Int]]` to `Array<Array<int>>`"), + std::string::npos); + throw; + } + }, + ::tvm::ffi::Error); +} + +TEST(Array, MutateInPlaceForUniqueReference) { + TInt x(1); + Array<TInt> arr{x, x}; + EXPECT_TRUE(arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](TInt) { return TInt(2); }); + auto* after = arr.get(); + EXPECT_EQ(before, after); +} + +TEST(Array, CopyWhenMutatingNonUniqueReference) { + TInt x(1); + Array<TInt> arr{x, x}; + Array<TInt> arr2 = arr; + + EXPECT_TRUE(!arr.unique()); + auto* before = arr.get(); + + arr.MutateByApply([](TInt) { return TInt(2); }); + auto* after = arr.get(); + EXPECT_NE(before, after); +} + +TEST(Array, Map) { + // Map builds a new array; each TInt becomes a TFloat holding value + 1 + TInt x(1), y(1); + Array<TInt> var_arr{x, y}; + Array<TNumber> expr_arr = var_arr.Map([](TInt var) -> TNumber { return TFloat(var->value + 1); }); + + EXPECT_NE(var_arr.get(), expr_arr.get()); + EXPECT_TRUE(expr_arr[0]->IsInstance<TFloatObj>()); + EXPECT_TRUE(expr_arr[1]->IsInstance<TFloatObj>()); +} + +TEST(Array, Iterator) { + Array<int> array{1, 2, 3}; + std::vector<int> vector(array.begin(), array.end()); + EXPECT_EQ(vector[1], 2); +} + +TEST(Array, PushPop) { + Array<int> a; + std::vector<int> b; + for (int i = 0; i < 10; ++i) { + a.push_back(i); + b.push_back(i); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + for (int i = 9; i >= 0; --i) { + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + a.pop_back(); + b.pop_back(); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + ASSERT_EQ(a.empty(), true); +} + +TEST(Array, ResizeReserveClear) { + for (size_t n = 0; n < 10; ++n) { + Array<int> a; + Array<int> b; + a.resize(n); + b.reserve(n); + ASSERT_EQ(a.size(), n); + ASSERT_GE(a.capacity(), n); // NOTE(review): b.capacity() after b.reserve(n) is never checked + a.clear();
+ b.clear(); + ASSERT_EQ(a.size(), 0); + ASSERT_EQ(b.size(), 0); + } +} + +TEST(Array, InsertErase) { + Array<int> a; + std::vector<int> b; + for (int n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (int pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, pos); + b.insert(b.begin() + pos, pos); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + 1); + ASSERT_EQ(b.size(), n + 1); + for (int k = 0; k <= n; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos); + b.erase(b.begin() + pos); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +TEST(Array, InsertEraseRange) { + Array<int> range_a{-1, -2, -3, -4}; + std::vector<int> range_b{-1, -2, -3, -4}; + Array<int> a; + std::vector<int> b; + + static_assert(std::is_same_v<decltype(*range_a.begin()), int>); // iterator yields int by value + for (size_t n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (size_t pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, range_a.begin(), range_a.end()); + b.insert(b.begin() + pos, range_b.begin(), range_b.end()); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + range_a.size()); + ASSERT_EQ(b.size(), n + range_b.size()); + size_t m = n + range_a.size(); + for (size_t k = 0; k < m; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); + b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +} // namespace
