tqchen commented on code in PR #443: URL: https://github.com/apache/tvm-ffi/pull/443#discussion_r2804205384
########## include/tvm/ffi/container/seq_base.h: ########## @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/seq_base.h + * \brief Base class for sequence containers (Array, List). + */ +#ifndef TVM_FFI_CONTAINER_SEQ_BASE_H_ +#define TVM_FFI_CONTAINER_SEQ_BASE_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/object.h> + +#include <algorithm> +#include <cstddef> +#include <iterator> +#include <utility> + +namespace tvm { +namespace ffi { + +/*! + * \brief Base class for sequence containers (ArrayObj, ListObj). + * + * SeqBaseObj is transparent to the FFI type system (no type index), + * following the same pattern as BytesObjBase. + */ +class SeqBaseObj : public Object, protected TVMFFISeqCell { + public: + SeqBaseObj() { + data = nullptr; + TVMFFISeqCell::size = 0; + TVMFFISeqCell::capacity = 0; + data_deleter = nullptr; + } + + ~SeqBaseObj() { + Any* begin = MutableBegin(); + for (int64_t i = 0; i < TVMFFISeqCell::size; ++i) { + (begin + i)->Any::~Any(); + } + if (data_deleter != nullptr) { + data_deleter(data); + } + } + + /*! \return The size of the sequence */ + size_t size() const { return static_cast<size_t>(TVMFFISeqCell::size); } + + /*! 
\return The capacity of the sequence */ + size_t capacity() const { return static_cast<size_t>(TVMFFISeqCell::capacity); } + + /*! \return Whether the sequence is empty */ + bool empty() const { return TVMFFISeqCell::size == 0; } + + /*! + * \brief Read i-th element from the sequence. + * \param i The index + * \return the i-th element. + */ + const Any& at(int64_t i) const { return this->operator[](i); } + + /*! + * \brief Read i-th element from the sequence. + * \param i The index + * \return the i-th element. + */ + const Any& operator[](int64_t i) const { + if (i < 0 || i >= TVMFFISeqCell::size) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << TVMFFISeqCell::size; + } + return static_cast<Any*>(data)[i]; + } + + /*! \return The first element */ + const Any& front() const { + if (TVMFFISeqCell::size == 0) { + TVM_FFI_THROW(IndexError) << "front() on empty sequence"; + } + return static_cast<Any*>(data)[0]; + } + + /*! \return The last element */ + const Any& back() const { + if (TVMFFISeqCell::size == 0) { + TVM_FFI_THROW(IndexError) << "back() on empty sequence"; + } + return static_cast<Any*>(data)[TVMFFISeqCell::size - 1]; + } + + /*! \return begin constant iterator */ + const Any* begin() const { return static_cast<Any*>(data); } + + /*! \return end constant iterator */ + const Any* end() const { return begin() + TVMFFISeqCell::size; } + + /*! \brief Release reference to all the elements */ + void clear() { + Any* itr = MutableEnd(); + while (TVMFFISeqCell::size > 0) { + (--itr)->Any::~Any(); + --TVMFFISeqCell::size; + } + } + + /*! + * \brief Set i-th element of the sequence in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, Any item) { + if (i < 0 || i >= TVMFFISeqCell::size) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << TVMFFISeqCell::size; + } + static_cast<Any*>(data)[i] = std::move(item); + } + + /*! 
\brief Remove the last element */ + void pop_back() { + if (TVMFFISeqCell::size == 0) { + TVM_FFI_THROW(IndexError) << "pop_back on empty sequence"; + } + ShrinkBy(1); + } + + /*! + * \brief Erase element at position idx + * \param idx The index to erase + */ + void erase(int64_t idx) { + if (idx < 0 || idx >= TVMFFISeqCell::size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << TVMFFISeqCell::size; + } + MoveElementsLeft(idx, idx + 1, TVMFFISeqCell::size); + ShrinkBy(1); + } + + /*! + * \brief Erase elements in half-open range [first, last) + * \param first Start index (inclusive) + * \param last End index (exclusive) + */ + void erase(int64_t first, int64_t last) { + if (first == last) return; + if (first < 0 || last > TVMFFISeqCell::size || first >= last) { + TVM_FFI_THROW(IndexError) << "Erase range [" << first << ", " << last << ") out of bounds " + << TVMFFISeqCell::size; + } + MoveElementsLeft(first, last, TVMFFISeqCell::size); + ShrinkBy(last - first); + } + + /*! + * \brief Insert element at position idx + * \param idx The index to insert at + * \param item The value to insert + * \note Caller must ensure capacity >= size + 1 + */ + void insert(int64_t idx, Any item) { + int64_t sz = TVMFFISeqCell::size; + if (idx < 0 || idx > sz) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " << sz << "]"; + } + EnlargeBy(1); + MoveElementsRight(idx + 1, idx, sz); + MutableBegin()[idx] = std::move(item); + } + + /*! 
+ * \brief Insert elements from iterator range at position idx + * \param idx The index to insert at + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \note Caller must ensure capacity >= size + distance(first, last) + */ + template <typename IterType> + void insert(int64_t idx, IterType first, IterType last) { + int64_t count = std::distance(first, last); + if (count == 0) return; + int64_t sz = TVMFFISeqCell::size; + if (idx < 0 || idx > sz) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " << sz << "]"; + } + EnlargeBy(count); + MoveElementsRight(idx + count, idx, sz); + Any* dst = MutableBegin() + idx; + for (; first != last; ++first, ++dst) { + *dst = Any(*first); + } + } + + /*! \brief Reverse the elements in-place */ + void Reverse() { std::reverse(MutableBegin(), MutableBegin() + TVMFFISeqCell::size); } + + /*! + * \brief Resize the sequence + * \param n The new size + * \note Caller must ensure capacity >= n when growing + */ + void resize(int64_t n) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "Cannot resize to negative size"; + } + int64_t old_size = TVMFFISeqCell::size; + if (old_size < n) { + EnlargeBy(n - old_size); + } else if (old_size > n) { + ShrinkBy(old_size - n); + } + } + + protected: + /// \cond Doxygen_Suppress + Any* MutableBegin() const { return static_cast<Any*>(this->data); } + + Any* MutableEnd() const { return MutableBegin() + TVMFFISeqCell::size; } + + template <typename... Args> + void EmplaceInit(size_t idx, Args&&... 
args) { + Any* itr = MutableBegin() + idx; + new (itr) Any(std::forward<Args>(args)...); + } + + void EnlargeBy(int64_t delta, const Any& val = Any()) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) Any(val); + ++TVMFFISeqCell::size; + } + } + + void ShrinkBy(int64_t delta) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->Any::~Any(); + --TVMFFISeqCell::size; + } + } + + void MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* begin = MutableBegin(); + std::move(begin + src_begin, begin + src_end, begin + dst); + } + + void MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* begin = MutableBegin(); + std::move_backward(begin + src_begin, begin + src_end, begin + dst + (src_end - src_begin)); + } + /// \endcond +}; + +template <typename Derived, typename SeqRef, typename T> +struct SeqTypeTraitsBase : public ObjectRefTypeTraitsBase<SeqRef> { Review Comment: document the expected derived properties ########## include/tvm/ffi/container/list.h: ########## @@ -0,0 +1,523 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/list.h + * \brief Mutable list type. 
+ * + * tvm::ffi::List<Any> is an erased mutable sequence container. + */ +#ifndef TVM_FFI_CONTAINER_LIST_H_ +#define TVM_FFI_CONTAINER_LIST_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/seq_base.h> +#include <tvm/ffi/object.h> + +#include <algorithm> +#include <initializer_list> +#include <sstream> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +namespace tvm { +namespace ffi { + +/*! \brief List node content in list */ +class ListObj : public SeqBaseObj { + public: + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ListObj requested + */ + static ObjectPtr<ListObj> CreateRepeated(int64_t n, const Any& val) { + ObjectPtr<ListObj> p = ListObj::Empty(n); + Any* itr = p->MutableBegin(); + for (int64_t& i = p->TVMFFISeqCell::size = 0; i < n; ++i) { + new (itr++) Any(val); + } + return p; + } + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIList; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIList, ListObj, Object); + /// \endcond + + private: + /*! + * \brief Ensure the list has at least n slots. + * \param n The lower bound of required capacity. + */ + void Reserve(int64_t n) { + if (n <= TVMFFISeqCell::capacity) { + return; + } + Any* old_data = MutableBegin(); + Any* new_data = static_cast<Any*>(::operator new(sizeof(Any) * static_cast<size_t>(n))); + for (int64_t i = 0; i < TVMFFISeqCell::size; ++i) { Review Comment: This is not leak-safe: if the move constructor of Any throws partway through this loop, new_data (and the elements already moved into it) would leak. It might be worth cross-checking and ensuring that Any's move constructor is noexcept. ########## include/tvm/ffi/container/list.h: ########## @@ -0,0 +1,810 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/list.h + * \brief Mutable list type. + * + * tvm::ffi::List<Any> is an erased mutable sequence container. + */ +#ifndef TVM_FFI_CONTAINER_LIST_H_ +#define TVM_FFI_CONTAINER_LIST_H_ + +#include <tvm/ffi/any.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/object.h> + +#include <algorithm> +#include <initializer_list> +#include <sstream> +#include <string> +#include <type_traits> +#include <utility> +#include <vector> + +namespace tvm { +namespace ffi { + +/*! \brief List node content in list */ +class ListObj : public Object, protected TVMFFISeqCell { + public: + ~ListObj() { + Any* begin = MutableBegin(); + for (int64_t i = 0; i < length; ++i) { + (begin + i)->Any::~Any(); + } + if (data_deleter != nullptr) { + data_deleter(data); + } + } + + /*! \return The size of the list */ + size_t size() const { return static_cast<size_t>(length); } + + /*! + * \brief Read i-th element from list. + * \param i The index + * \return the i-th element. + */ + const Any& at(int64_t i) const { return this->operator[](i); } + + /*! + * \brief Read i-th element from list. + * \param i The index + * \return the i-th element. 
+ */ + const Any& operator[](int64_t i) const { + if (i < 0 || i >= length) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << length; + } + return MutableBegin()[i]; + } + + /*! \return begin constant iterator */ + const Any* begin() const { return MutableBegin(); } + + /*! \return end constant iterator */ + const Any* end() const { return MutableBegin() + length; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(length); } + + /*! \brief Reverse the elements in-place */ + void Reverse() { std::reverse(MutableBegin(), MutableBegin() + length); } + + /*! + * \brief Set i-th element of the list in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, Any item) { + if (i < 0 || i >= length) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << length; + } + MutableBegin()[i] = std::move(item); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ListObj requested + */ + static ObjectPtr<ListObj> CreateRepeated(int64_t n, const Any& val) { + ObjectPtr<ListObj> p = ListObj::Empty(n); + Any* itr = p->MutableBegin(); + for (int64_t& i = p->length = 0; i < n; ++i) { + new (itr++) Any(val); + } + return p; + } + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIList; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIList, ListObj, Object); + /// \endcond + + private: + /*! \return begin mutable iterator */ + Any* MutableBegin() const { return static_cast<Any*>(this->data); } + + /*! \return end mutable iterator */ + Any* MutableEnd() const { return MutableBegin() + length; } + + /*! + * \brief Emplace a new element at the given index + * \param idx The index of the element. 
+ * \param args The arguments to construct the new element + */ + template <typename... Args> + void EmplaceInit(size_t idx, Args&&... args) { + Any* itr = MutableBegin() + idx; + new (itr) Any(std::forward<Args>(args)...); + } + + /*! + * \brief Assign elements into existing slots from [first, last). + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \note Callers must ensure target slots are already initialized. + * \return Self + */ + template <typename IterType> + ListObj* AssignRange(int64_t idx, IterType first, IterType last) { + Any* itr = MutableBegin() + idx; + for (; first != last; ++first) { + *itr++ = Any(*first); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ListObj* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* begin = MutableBegin(); + std::move(begin + src_begin, begin + src_end, begin + dst); + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ListObj* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + Any* begin = MutableBegin(); + std::move_backward(begin + src_begin, begin + src_end, begin + dst + (src_end - src_begin)); + return this; + } + + /*! + * \brief Enlarges the size of the list + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ListObj* EnlargeBy(int64_t delta, const Any& val = Any()) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) Any(val); + ++length; + } + return this; + } + + /*! 
+ * \brief Shrinks the size of the list + * \param delta Size shrinked, should be positive + * \return Self + */ + ListObj* ShrinkBy(int64_t delta) { + Any* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->Any::~Any(); + --length; + } + return this; + } + + /*! + * \brief Ensure the list has at least n slots. + * \param n The lower bound of required capacity. + */ + void Reserve(int64_t n) { + if (n <= capacity) { + return; + } + Any* old_data = MutableBegin(); + Any* new_data = static_cast<Any*>(::operator new(sizeof(Any) * static_cast<size_t>(n))); + for (int64_t i = 0; i < length; ++i) { + new (new_data + i) Any(std::move(old_data[i])); + } + for (int64_t j = 0; j < length; ++j) { + (old_data + j)->Any::~Any(); + } + data_deleter(data); + data = new_data; + capacity = n; + } + + /*! + * \brief Create an empty ListObj with the given capacity. + * \param n Required capacity + * \return Ref-counted ListObj requested + */ + static ObjectPtr<ListObj> Empty(int64_t n = kInitSize) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "cannot construct a List of negative size"; + } + ObjectPtr<ListObj> p = make_object<ListObj>(); + p->capacity = n; + p->length = 0; + p->data = n == 0 ? nullptr : static_cast<void*>(::operator new(sizeof(Any) * n)); + p->data_deleter = RawDataDeleter; + return p; + } + + static void RawDataDeleter(void* data) { ::operator delete(data); } + + /*! \brief Initial size of ListObj */ + static constexpr int64_t kInitSize = 4; + /*! \brief Expansion factor of the List */ + static constexpr int64_t kIncFactor = 2; + + template <typename, typename> + friend class List; + + template <typename, typename> + friend struct TypeTraits; +}; + +/*! + * \brief List, container representing a mutable contiguous sequence of ObjectRefs. + * + * Unlike Array, List is mutable and does not implement copy-on-write. + * Mutations happen directly on the underlying shared ListObj. + * + * \note Thread safety: List is **not** thread-safe. 
Concurrent reads and writes + * from multiple threads require external synchronization. + * + * \warning Because List elements are stored as `Any` (which may hold `ObjectRef` + * pointers), it is possible to create reference cycles (e.g. a List that + * contains itself). Such cycles will **not** be collected by the reference- + * counting mechanism alone; avoid them in long-lived data structures. + * + * \tparam T The content Value type, must be compatible with tvm::ffi::Any + */ +template <typename T, typename = typename std::enable_if_t<details::storage_enabled_v<T>>> +class List : public ObjectRef { + public: + /*! \brief The value type of the list */ + using value_type = T; + + /*! \brief Construct a List with UnsafeInit */ + explicit List(UnsafeInit tag) : ObjectRef(tag) {} + /*! \brief default constructor */ + List() { data_ = ListObj::Empty(); } // NOLINT(modernize-use-equals-default) + /*! \brief Move constructor */ + List(List<T>&& other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + /*! \brief Copy constructor */ + List(const List<T>& other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) + + /*! + * \brief Constructor from another list + * \tparam U The value type of the other list + */ + template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>> + List(List<U>&& other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + + /*! + * \brief Constructor from another list + * \tparam U The value type of the other list + */ + template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>> + List(const List<U>& other) // NOLINT(google-explicit-constructor) + : ObjectRef(other.data_) {} + + /*! + * \brief Move assignment from another list. + * \param other The other list. + */ + TVM_FFI_INLINE List<T>& operator=(List<T>&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Assignment from another list. 
+ * \param other The other list. + */ + TVM_FFI_INLINE List<T>& operator=(const List<T>& other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief Move assignment from another list. + * \param other The other list. + * \tparam U The value type of the other list. + */ + template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>> + TVM_FFI_INLINE List<T>& operator=(List<U>&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Assignment from another list. + * \param other The other list. + * \tparam U The value type of the other list. + */ + template <typename U, typename = std::enable_if_t<details::type_contains_v<T, U>>> + TVM_FFI_INLINE List<T>& operator=(const List<U>& other) { + data_ = other.data_; + return *this; + } + + /*! \brief Constructor from pointer */ + explicit List(ObjectPtr<Object> n) : ObjectRef(std::move(n)) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + List(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) + static_assert(is_valid_iterator_v<T, IterType>, + "IterType cannot be inserted into a tvm::List<T>"); + Assign(first, last); + } + + /*! \brief constructor from initializer list */ + List(std::initializer_list<T> init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! \brief constructor from vector */ + List(const std::vector<T>& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. 
Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit List(const size_t n, const T& val) { data_ = ListObj::CreateRepeated(n, val); } + + public: + // iterators + /// \cond Doxygen_Suppress + struct ValueConverter { + using ResultType = T; + static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); } + }; + /// \endcond + + /*! \brief The iterator type of the list */ + using iterator = details::IterAdapter<ValueConverter, const Any*>; + /*! \brief The reverse iterator type of the list */ + using reverse_iterator = details::ReverseIterAdapter<ValueConverter, const Any*>; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetListObj()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetListObj()->end()); } + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { return reverse_iterator(GetListObj()->end() - 1); } + /*! \return rend iterator */ + reverse_iterator rend() const { return reverse_iterator(GetListObj()->begin() - 1); } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from list. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ListObj* p = GetListObj(); + if (p == nullptr) { + TVM_FFI_THROW(IndexError) << "cannot index a null list"; + } + if (i < 0 || i >= p->length) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on a list of size " << p->length; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin() + i)); + } + + /*! \return The size of the list */ + size_t size() const { + ListObj* p = GetListObj(); + return p == nullptr ? 0 : static_cast<size_t>(p->length); + } + + /*! \return The capacity of the list */ + size_t capacity() const { + ListObj* p = GetListObj(); + return p == nullptr ? 0 : static_cast<size_t>(p->capacity); + } + + /*! 
\return Whether list is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the list */ + const T front() const { + ListObj* p = GetListObj(); + if (p == nullptr || p->length == 0) { + TVM_FFI_THROW(IndexError) << "cannot index an empty list"; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin())); + } + + /*! \return The last element of the list */ + const T back() const { + ListObj* p = GetListObj(); + if (p == nullptr || p->length == 0) { + TVM_FFI_THROW(IndexError) << "cannot index an empty list"; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->end() - 1)); + } + + public: + // mutation in std::vector + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ListObj* p = EnsureCapacity(1); + p->EmplaceInit(p->length++, item); + } + + /*! + * \brief Emplace a new element at the back of the list + * \param args The arguments to construct the new element + */ + template <typename... Args> + void emplace_back(Args&&... args) { + ListObj* p = EnsureCapacity(1); + p->EmplaceInit(p->length++, std::forward<Args>(args)...); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert to a null list"; + } + Any item(val); + int64_t idx = std::distance(begin(), position); + int64_t size = GetListObj()->length; + if (idx < 0 || idx > size) { + TVM_FFI_THROW(IndexError) << "cannot insert at index " << idx << ", because List size is " + << size; + } + auto addr = EnsureCapacity(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + addr[idx] = std::move(item); + } + + /*! 
+ * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template <typename IterType> + void insert(iterator position, IterType first, IterType last) { + static_assert(is_valid_iterator_v<T, IterType>, + "IterType cannot be inserted into a tvm::List<T>"); + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert to a null list"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetListObj()->length; + int64_t numel = std::distance(first, last); + if (idx < 0 || idx > size) { + TVM_FFI_THROW(IndexError) << "cannot insert at index " << idx << ", because List size is " + << size; + } + EnsureCapacity(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->AssignRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null list"; + } + int64_t size = GetListObj()->length; + if (size == 0) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty list"; + } + GetListObj()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null list"; + } + int64_t st = std::distance(begin(), position); + int64_t size = GetListObj()->length; + if (st < 0 || st >= size) { + TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because List size is " + << size; + } + GetListObj() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! 
+ * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null list"; + } + int64_t size = GetListObj()->length; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + if (st >= ed) { + TVM_FFI_THROW(IndexError) << "cannot erase list in range [" << st << ", " << ed << ")"; + } + if (st < 0 || st > size || ed < 0 || ed > size) { + TVM_FFI_THROW(IndexError) << "cannot erase list in range [" << st << ", " << ed << ")" + << ", because list size is " << size; + } + GetListObj() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the list. + * \param n The new size. + */ + void resize(int64_t n) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "cannot resize a List to negative size"; + } + ListObj* p = EnsureCapacity(std::max<int64_t>(0, n - static_cast<int64_t>(size()))); + int64_t old_size = p->length; + if (old_size < n) { + p->EnlargeBy(n - old_size); + } else if (old_size > n) { + p->ShrinkBy(old_size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { EnsureListObj()->Reserve(n); } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + GetListObj()->clear(); + } + } + + public: + // List's own methods + /*! + * \brief set i-th element of the list. + * \param i The index + * \param value The value to be set. + */ + void Set(int64_t i, T value) { + ListObj* p = EnsureListObj(); + if (i < 0 || i >= p->length) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on a list of size " << p->length; + } + *(p->MutableBegin() + i) = std::move(value); + } + + /*! 
\return The underlying ListObj */ + ListObj* GetListObj() const { return static_cast<ListObj*>(data_.get()); } + + /*! + * \brief reset the list to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + void Assign(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) + int64_t cap = std::distance(first, last); + if (cap < 0) { + TVM_FFI_THROW(ValueError) << "cannot construct a List of negative size"; + } + ListObj* p = EnsureListObj(); + p->Reserve(cap); + p->clear(); + Any* itr = p->MutableBegin(); + for (int64_t& i = p->length = 0; i < cap; ++i, ++first, ++itr) { + new (itr) Any(*first); + } + } + + /*! \brief specify container node */ + using ContainerType = ListObj; + + private: + /*! + * \brief Ensure the list object exists and has room for reserve_extra new entries. + * \param reserve_extra Number of extra slots needed + * \return ListObj pointer + */ + ListObj* EnsureCapacity(int64_t reserve_extra) { + ListObj* p = EnsureListObj(); + if (p->capacity >= p->length + reserve_extra) { + return p; + } + int64_t cap = p->capacity * ListObj::kIncFactor; + cap = std::max(cap, p->length + reserve_extra); + p->Reserve(cap); + return p; + } + + /*! + * \brief Ensure this list has a container object. 
+ * \return ListObj pointer + */ + ListObj* EnsureListObj() { + if (data_ == nullptr) { + data_ = ListObj::Empty(); + } + return static_cast<ListObj*>(data_.get()); + } + + template <typename, typename> + friend class List; +}; + +// Traits for List +template <typename T> +inline constexpr bool use_default_type_traits_v<List<T>> = false; + +template <typename T> +struct TypeTraits<List<T>> : public ObjectRefTypeTraitsBase<List<T>> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIList; + using ObjectRefTypeTraitsBase<List<T>>::CopyFromAnyViewAfterCheck; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIList && src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v<T, Any>) { + auto for_each = [&](auto* n) -> std::optional<std::string> { Review Comment: Consider moving this lambda into a member template function. ########## src/ffi/extra/structural_equal.cc: ########## @@ -302,6 +309,33 @@ class StructEqualHandler { // NOLINTNEXTLINE(performance-unnecessary-value-param) bool CompareArray(ffi::Array<Any> lhs, ffi::Array<Any> rhs) { + return CompareSequence(std::move(lhs), std::move(rhs)); + } + + // NOLINTNEXTLINE(performance-unnecessary-value-param) + bool CompareList(ffi::List<Any> lhs, ffi::List<Any> rhs) { + return CompareSequence(std::move(lhs), std::move(rhs)); + } + + template <typename SeqType> + // NOLINTNEXTLINE(performance-unnecessary-value-param) + bool CompareSequence(SeqType lhs, SeqType rhs) { + const Object* lhs_ptr = lhs.get(); + const Object* rhs_ptr = rhs.get(); + auto pair = std::make_pair(lhs_ptr, rhs_ptr); + if (active_sequence_pairs_.count(pair)) { Review Comment: We likely don't need cycle detection in structural equal; can we drop it to keep things simple?
########## src/ffi/extra/structural_hash.cc: ########## @@ -185,11 +190,25 @@ class StructuralHashHandler { } // NOLINTNEXTLINE(performance-unnecessary-value-param) - uint64_t HashArray(Array<Any> arr) { - uint64_t hash_value = details::StableHashCombine(arr->GetTypeKeyHash(), arr.size()); - for (const auto& elem : arr) { + uint64_t HashArray(Array<Any> arr) { return HashSequence(std::move(arr)); } + + // NOLINTNEXTLINE(performance-unnecessary-value-param) + uint64_t HashList(List<Any> list) { return HashSequence(std::move(list)); } + + template <typename SeqType> + // NOLINTNEXTLINE(performance-unnecessary-value-param) + uint64_t HashSequence(SeqType seq) { + const Object* obj_ptr = seq.get(); Review Comment: consider assuming DAG in structural hash to simplify things -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
