This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 9513c2f feat: Introduce List as Mutable Sequence (#443)
9513c2f is described below
commit 9513c2f8a57f64ad7473d7cd06084719f6d5e70e
Author: Junru Shao <[email protected]>
AuthorDate: Fri Feb 13 12:36:34 2026 -0800
feat: Introduce List as Mutable Sequence (#443)
## Summary
- Introduces `List<T>`, a mutable sequence container alongside the
existing immutable `Array<T>`, with full C++ and Python support
- Extracts `SeqBaseObj` as a shared base class for `ArrayObj` and
`ListObj`, consolidating type-erased sequence operations (iteration,
element access, reverse, etc.)
- Refactors `ArrayObj` to inherit from `SeqBaseObj`, significantly
reducing code duplication while preserving copy-on-write semantics
- Adds cycle detection for `List` (which, unlike `Array`, can form
reference cycles) in serialization, structural hash, structural equal,
and the JSON writer
- Updates `stl.h` so `TypeTraits<std::vector<T>>` accepts both
`kTVMFFIArray` and `kTVMFFIList`
## Test plan
- [x] C++ unit tests for `List` operations (`test_list.cc`:
construction, push_back, pop_back, insert, erase, resize, Set, clear,
iterator, reverse, COW-free mutation)
- [x] C++ tests for `List` structural equal/hash
(`test_structural_equal_hash.cc`: List equality, List-vs-Array type
mismatch)
- [x] C++ tests for `List` serialization (`test_serialization.cc`:
empty/single-element list, round-trip, cycle detection)
- [x] Python tests for `List` (`test_container.py`: construction,
indexing, slicing, append, insert, pop, extend, `__delitem__`,
`__setitem__`, `__contains__`, iteration, reverse, `len`, pickle
round-trip)
---
include/tvm/ffi/c_api.h | 39 +-
include/tvm/ffi/container/array.h | 365 ++++--------------
include/tvm/ffi/container/list.h | 527 ++++++++++++++++++++++++++
include/tvm/ffi/container/seq_base.h | 365 ++++++++++++++++++
include/tvm/ffi/container/tuple.h | 6 +-
include/tvm/ffi/extra/stl.h | 24 +-
include/tvm/ffi/object.h | 2 +
include/tvm/ffi/tvm_ffi.h | 1 +
python/tvm_ffi/__init__.py | 3 +-
python/tvm_ffi/_ffi_api.py | 26 ++
python/tvm_ffi/container.py | 159 +++++++-
python/tvm_ffi/cython/base.pxi | 1 +
python/tvm_ffi/cython/type_info.pxi | 1 +
python/tvm_ffi/testing/__init__.py | 1 +
python/tvm_ffi/testing/_ffi_api.py | 6 +
src/ffi/container.cc | 57 +++
src/ffi/extra/json_writer.cc | 22 +-
src/ffi/extra/serialization.cc | 47 ++-
src/ffi/extra/structural_equal.cc | 17 +
src/ffi/extra/structural_hash.cc | 17 +-
src/ffi/object.cc | 1 +
src/ffi/testing/testing.cc | 9 +
tests/cpp/extra/test_serialization.cc | 41 ++
tests/cpp/extra/test_structural_equal_hash.cc | 41 ++
tests/cpp/test_list.cc | 277 ++++++++++++++
tests/python/test_container.py | 244 ++++++++++++
tests/python/test_object.py | 1 +
27 files changed, 1964 insertions(+), 336 deletions(-)
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 30fc865..5089b6e 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -173,6 +173,8 @@ typedef enum {
* \sa TVMFFIObjectCreateOpaque
*/
kTVMFFIOpaquePyObject = 74,
+ /*! \brief List object. */
+ kTVMFFIList = 75,
//----------------------------------------------------------------
// more complex objects
//----------------------------------------------------------------
@@ -346,6 +348,7 @@ typedef struct {
/*! \brief The size of the data. */
size_t size;
} TVMFFIByteArray;
+// [TVMFFIByteArray.end]
/*!
* \brief Shape cell used in shape object following header.
@@ -356,7 +359,41 @@ typedef struct {
/*! \brief The size of the data. */
size_t size;
} TVMFFIShapeCell;
-// [TVMFFIByteArray.end]
+
+// [TVMFFISeqCell.begin]
+/*!
+ * \brief Sequence cell used by sequence-like containers.
+ *
+ * Holds the bookkeeping shared by sequence containers: a pointer to the
+ * element buffer, the number of live elements, the number of allocated
+ * slots and an optional deleter for the buffer. ArrayObj and ListObj both
+ * inherit from this cell (through their common base SeqBaseObj).
+ *
+ * C++ sees a named struct, while C sees a typedef of an anonymous struct
+ * (the same form as the other cells in this header). Only the opening and
+ * closing lines differ between the two; the field list below is shared.
+ */
+#ifdef __cplusplus
+struct TVMFFISeqCell {
+#else
+typedef struct {
+#endif
+ /*!
+ * \brief Pointer to the first element of the sequence.
+ *
+ * The element type is chosen by the container; ArrayObj and ListObj
+ * store tvm::ffi::Any values in this buffer.
+ */
+ void* data;
+ /*! \brief Number of initialized (live) elements in data. */
+ int64_t size;
+ /*!
+ * \brief Number of element slots allocated in data.
+ *
+ * NOTE(review): containers are expected to maintain 0 <= size <= capacity.
+ */
+ int64_t capacity;
+ /*!
+ * \brief Optional deleter for the data buffer.
+ *
+ * When non-null, data was allocated separately from the object
+ * (e.g. ListObj heap buffer), and data_deleter(data) is called to free it.
+ *
+ * When nullptr, data lives inside the object allocation itself
+ * (e.g. ArrayObj inplace storage via make_inplace_array_object)
+ * and is freed together with the object.
+ */
+ void (*data_deleter)(void*);
+#ifdef __cplusplus
+};
+#else
+} TVMFFISeqCell;
+#endif
+// [TVMFFISeqCell.end]
/*!
* \brief Mode to update the backtrace of the error.
diff --git a/include/tvm/ffi/container/array.h
b/include/tvm/ffi/container/array.h
index 32c1a22..2712515 100644
--- a/include/tvm/ffi/container/array.h
+++ b/include/tvm/ffi/container/array.h
@@ -28,6 +28,7 @@
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/container_details.h>
+#include <tvm/ffi/container/seq_base.h>
#include <tvm/ffi/memory.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/optional.h>
@@ -42,61 +43,18 @@ namespace tvm {
namespace ffi {
/*! \brief Array node content in array */
-class ArrayObj : public Object, public details::InplaceArrayBase<ArrayObj,
TVMFFIAny> {
- public:
- ~ArrayObj() {
- Any* begin = MutableBegin();
- for (int64_t i = 0; i < size_; ++i) {
- (begin + i)->Any::~Any();
- }
- if (data_deleter_ != nullptr) {
- data_deleter_(data_);
- }
- }
-
- /*! \return The size of the array */
- size_t size() const { return this->size_; }
-
- /*!
- * \brief Read i-th element from array.
- * \param i The index
- * \return the i-th element.
- */
- const Any& at(int64_t i) const { return this->operator[](i); }
-
- /*!
- * \brief Read i-th element from array.
- * \param i The index
- * \return the i-th element.
- */
- const Any& operator[](int64_t i) const {
- if (i < 0 || i >= size_) {
- TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_;
- }
- return static_cast<Any*>(data_)[i];
- }
-
- /*! \return begin constant iterator */
- const Any* begin() const { return static_cast<Any*>(data_); }
-
- /*! \return end constant iterator */
- const Any* end() const { return begin() + size_; }
-
- /*! \brief Release reference to all the elements */
- void clear() { ShrinkBy(size_); }
-
- /*!
- * \brief Set i-th element of the array in-place
- * \param i The index
- * \param item The value to be set
- */
- void SetItem(int64_t i, Any item) {
- if (i < 0 || i >= size_) {
- TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_;
- }
- static_cast<Any*>(data_)[i] = std::move(item);
- }
+class ArrayObj : public SeqBaseObj, public details::InplaceArrayBase<ArrayObj,
TVMFFIAny> {
+ // InplaceArrayBase<ArrayObj, TVMFFIAny>'s destructor is compiled out
because TVMFFIAny is
+ // trivial. SeqBaseObj::~SeqBaseObj() handles element destruction. Changing
the ElemType to a
+ // non-trivial type would cause double-destruction.
+ static_assert(std::is_trivially_destructible_v<TVMFFIAny>,
+ "InplaceArrayBase<ArrayObj, TVMFFIAny> must use a trivially
destructible "
+ "element type to avoid double-destruction with
SeqBaseObj::~SeqBaseObj()");
+ public:
+ // Bring SeqBaseObj names into ArrayObj scope to hide InplaceArrayBase's
versions
+ using SeqBaseObj::operator[];
+ using SeqBaseObj::EmplaceInit;
/*!
* \brief Constructs a container and copy from another
* \param cap The capacity of the container
@@ -104,7 +62,7 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
* \return Ref-counted ArrayObj requested
*/
static ObjectPtr<ArrayObj> CopyFrom(int64_t cap, ArrayObj* from) {
- int64_t size = from->size_;
+ int64_t size = from->TVMFFISeqCell::size;
if (size > cap) {
TVM_FFI_THROW(ValueError) << "Not enough capacity";
}
@@ -112,7 +70,7 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
Any* write = p->MutableBegin();
Any* read = from->MutableBegin();
// To ensure exception safety, size is only incremented after the
initialization succeeds
- for (int64_t& i = p->size_ = 0; i < size; ++i) {
+ for (int64_t& i = p->TVMFFISeqCell::size = 0; i < size; ++i) {
new (write++) Any(*read++);
}
return p;
@@ -125,7 +83,7 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
* \return Ref-counted ArrayObj requested
*/
static ObjectPtr<ArrayObj> MoveFrom(int64_t cap, ArrayObj* from) {
- int64_t size = from->size_;
+ int64_t size = from->TVMFFISeqCell::size;
if (size > cap) {
TVM_FFI_THROW(RuntimeError) << "Not enough capacity";
}
@@ -133,10 +91,10 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
Any* write = p->MutableBegin();
Any* read = from->MutableBegin();
// To ensure exception safety, size is only incremented after the
initialization succeeds
- for (int64_t& i = p->size_ = 0; i < size; ++i) {
+ for (int64_t& i = p->TVMFFISeqCell::size = 0; i < size; ++i) {
new (write++) Any(std::move(*read++));
}
- from->size_ = 0;
+ from->TVMFFISeqCell::size = 0;
return p;
}
@@ -149,7 +107,7 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
static ObjectPtr<ArrayObj> CreateRepeated(int64_t n, const Any& val) {
ObjectPtr<ArrayObj> p = ArrayObj::Empty(n);
Any* itr = p->MutableBegin();
- for (int64_t& i = p->size_ = 0; i < n; ++i) {
+ for (int64_t& i = p->TVMFFISeqCell::size = 0; i < n; ++i) {
new (itr++) Any(val);
}
return p;
@@ -163,24 +121,7 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
private:
/*! \return Size of initialized memory, used by InplaceArrayBase. */
- size_t GetSize() const { return this->size_; }
-
- /*! \return begin mutable iterator */
- Any* MutableBegin() const { return static_cast<Any*>(this->data_); }
-
- /*! \return end mutable iterator */
- Any* MutableEnd() const { return MutableBegin() + size_; }
-
- /*!
- * \brief Emplace a new element at the back of the array
- * \param idx The index of the element.
- * \param args The arguments to construct the new element
- */
- template <typename... Args>
- void EmplaceInit(size_t idx, Args&&... args) {
- Any* itr = MutableBegin() + idx;
- new (itr) Any(std::forward<Args>(args)...);
- }
+ size_t GetSize() const { return TVMFFISeqCell::size; }
/*!
* \brief Create an ArrayObj with the given capacity.
@@ -189,9 +130,10 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
*/
static ObjectPtr<ArrayObj> Empty(int64_t n = kInitSize) {
ObjectPtr<ArrayObj> p = make_inplace_array_object<ArrayObj, Any>(n);
- p->capacity_ = n;
- p->size_ = 0;
- p->data_ = p->AddressOf(0);
+ p->TVMFFISeqCell::capacity = n;
+ p->TVMFFISeqCell::size = 0;
+ p->data = p->AddressOf(0);
+ p->data_deleter = nullptr;
return p;
}
@@ -213,79 +155,6 @@ class ArrayObj : public Object, public
details::InplaceArrayBase<ArrayObj, TVMFF
return this;
}
- /*!
- * \brief Move elements from right to left, requires src_begin > dst
- * \param dst Destination
- * \param src_begin The start point of copy (inclusive)
- * \param src_end The end point of copy (exclusive)
- * \return Self
- */
- ArrayObj* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) {
- Any* from = MutableBegin() + src_begin;
- Any* to = MutableBegin() + dst;
- while (src_begin++ != src_end) {
- *to++ = std::move(*from++);
- }
- return this;
- }
-
- /*!
- * \brief Move elements from left to right, requires src_begin < dst
- * \param dst Destination
- * \param src_begin The start point of move (inclusive)
- * \param src_end The end point of move (exclusive)
- * \return Self
- */
- ArrayObj* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end)
{
- Any* from = MutableBegin() + src_end;
- Any* to = MutableBegin() + (src_end - src_begin + dst);
- while (src_begin++ != src_end) {
- *--to = std::move(*--from);
- }
- return this;
- }
-
- /*!
- * \brief Enlarges the size of the array
- * \param delta Size enlarged, should be positive
- * \param val Default value
- * \return Self
- */
- ArrayObj* EnlargeBy(int64_t delta, const Any& val = Any()) {
- Any* itr = MutableEnd();
- while (delta-- > 0) {
- new (itr++) Any(val);
- ++size_;
- }
- return this;
- }
-
- /*!
- * \brief Shrinks the size of the array
- * \param delta Size shrinked, should be positive
- * \return Self
- */
- ArrayObj* ShrinkBy(int64_t delta) {
- Any* itr = MutableEnd();
- while (delta-- > 0) {
- (--itr)->Any::~Any();
- --size_;
- }
- return this;
- }
-
- /*! \brief Data pointer to the first element of the array */
- void* data_;
- /*! \brief Number of elements used */
- int64_t size_;
- /*! \brief Number of elements allocated */
- int64_t capacity_;
- /*!
- * \brief Optional data deleter when data is allocated separately
- * and its deletion is not managed by ArrayObj::deleter_.
- */
- void (*data_deleter_)(void*) = nullptr;
-
/*! \brief Initial size of ArrayObj */
static constexpr int64_t kInitSize = 4;
@@ -520,48 +389,45 @@ class Array : public ObjectRef {
* \param i The index
* \return the i-th element.
*/
- const T operator[](int64_t i) const {
+ T operator[](int64_t i) const {
ArrayObj* p = GetArrayObj();
if (p == nullptr) {
TVM_FFI_THROW(IndexError) << "cannot index a null array";
}
- if (i < 0 || i >= p->size_) {
- TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size "
<< p->size_;
- }
- return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin() + i));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->at(i));
}
/*! \return The size of the array */
size_t size() const {
ArrayObj* p = GetArrayObj();
- return p == nullptr ? 0 : GetArrayObj()->size_;
+ return p == nullptr ? 0 : p->size();
}
/*! \return The capacity of the array */
size_t capacity() const {
ArrayObj* p = GetArrayObj();
- return p == nullptr ? 0 : GetArrayObj()->capacity_;
+ return p == nullptr ? 0 : p->SeqBaseObj::capacity();
}
/*! \return Whether array is empty */
bool empty() const { return size() == 0; }
/*! \return The first element of the array */
- const T front() const {
+ T front() const {
ArrayObj* p = GetArrayObj();
- if (p == nullptr || p->size_ == 0) {
- TVM_FFI_THROW(IndexError) << "cannot index a empty array";
+ if (p == nullptr) {
+ TVM_FFI_THROW(IndexError) << "cannot index a null array";
}
- return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->begin()));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->front());
}
/*! \return The last element of the array */
- const T back() const {
+ T back() const {
ArrayObj* p = GetArrayObj();
- if (p == nullptr || p->size_ == 0) {
- TVM_FFI_THROW(IndexError) << "cannot index a empty array";
+ if (p == nullptr) {
+ TVM_FFI_THROW(IndexError) << "cannot index a null array";
}
- return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(*(p->end() - 1));
+ return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(p->back());
}
public:
@@ -572,7 +438,7 @@ class Array : public ObjectRef {
*/
void push_back(const T& item) {
ArrayObj* p = CopyOnWrite(1);
- p->EmplaceInit(p->size_++, item);
+ p->EmplaceInit(p->TVMFFISeqCell::size++, item);
}
/*!
@@ -582,7 +448,7 @@ class Array : public ObjectRef {
template <typename... Args>
void emplace_back(Args&&... args) {
ArrayObj* p = CopyOnWrite(1);
- p->EmplaceInit(p->size_++, std::forward<Args>(args)...);
+ p->EmplaceInit(p->TVMFFISeqCell::size++, std::forward<Args>(args)...);
}
/*!
@@ -595,12 +461,7 @@ class Array : public ObjectRef {
TVM_FFI_THROW(RuntimeError) << "cannot insert a null array";
}
int64_t idx = std::distance(begin(), position);
- int64_t size = GetArrayObj()->size_;
- auto addr = CopyOnWrite(1) //
- ->EnlargeBy(1) //
- ->MoveElementsRight(idx + 1, idx, size) //
- ->MutableBegin();
- new (addr + idx) Any(val);
+ CopyOnWrite(1)->insert(idx, Any(val));
}
/*!
@@ -613,20 +474,13 @@ class Array : public ObjectRef {
void insert(iterator position, IterType first, IterType last) {
static_assert(is_valid_iterator_v<T, IterType>,
"IterType cannot be inserted into a tvm::Array<T>");
-
- if (first == last) {
- return;
- }
+ if (first == last) return;
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "cannot insert a null array";
}
int64_t idx = std::distance(begin(), position);
- int64_t size = GetArrayObj()->size_;
int64_t numel = std::distance(first, last);
- CopyOnWrite(numel)
- ->EnlargeBy(numel)
- ->MoveElementsRight(idx + numel, idx, size)
- ->InitRange(idx, first, last);
+ CopyOnWrite(numel)->insert(idx, first, last);
}
/*! \brief Remove the last item of the list */
@@ -634,11 +488,7 @@ class Array : public ObjectRef {
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array";
}
- int64_t size = GetArrayObj()->size_;
- if (size == 0) {
- TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array";
- }
- CopyOnWrite()->ShrinkBy(1);
+ CopyOnWrite()->pop_back();
}
/*!
@@ -649,15 +499,8 @@ class Array : public ObjectRef {
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "cannot erase a null array";
}
- int64_t st = std::distance(begin(), position);
- int64_t size = GetArrayObj()->size_;
- if (st < 0 || st >= size) {
- TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ",
because Array size is "
- << size;
- }
- CopyOnWrite() //
- ->MoveElementsLeft(st, st + 1, size) //
- ->ShrinkBy(1);
+ int64_t idx = std::distance(begin(), position);
+ CopyOnWrite()->erase(idx);
}
/*!
@@ -666,25 +509,13 @@ class Array : public ObjectRef {
* \param last The end iterator of the range
*/
void erase(iterator first, iterator last) {
- if (first == last) {
- return;
- }
+ if (first == last) return;
if (data_ == nullptr) {
TVM_FFI_THROW(RuntimeError) << "cannot erase a null array";
}
- int64_t size = GetArrayObj()->size_;
int64_t st = std::distance(begin(), first);
int64_t ed = std::distance(begin(), last);
- if (st >= ed) {
- TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ",
" << ed << ")";
- }
- if (st < 0 || st > size || ed < 0 || ed > size) {
- TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ",
" << ed << ")"
- << ", because array size is " << size;
- }
- CopyOnWrite() //
- ->MoveElementsLeft(st, ed, size) //
- ->ShrinkBy(ed - st);
+ CopyOnWrite()->erase(st, ed);
}
/*!
@@ -699,11 +530,11 @@ class Array : public ObjectRef {
SwitchContainer(n);
return;
}
- int64_t size = GetArrayObj()->size_;
- if (size < n) {
- CopyOnWrite(n - size)->EnlargeBy(n - size);
- } else if (size > n) {
- CopyOnWrite()->ShrinkBy(size - n);
+ int64_t cur_size = GetArrayObj()->TVMFFISeqCell::size;
+ if (cur_size < n) {
+ CopyOnWrite(n - cur_size)->resize(n);
+ } else if (cur_size > n) {
+ CopyOnWrite()->resize(n);
}
}
@@ -712,7 +543,7 @@ class Array : public ObjectRef {
* \param n lower bound of the capacity
*/
void reserve(int64_t n) {
- if (data_ == nullptr || n > GetArrayObj()->capacity_) {
+ if (data_ == nullptr || n >
static_cast<int64_t>(GetArrayObj()->SeqBaseObj::capacity())) {
SwitchContainer(n);
}
}
@@ -764,13 +595,7 @@ class Array : public ObjectRef {
* \param i The index
* \param value The value to be setted.
*/
- void Set(int64_t i, T value) {
- ArrayObj* p = this->CopyOnWrite();
- if (i < 0 || i >= p->size_) {
- TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size "
<< p->size_;
- }
- *(p->MutableBegin() + i) = std::move(value);
- }
+ void Set(int64_t i, T value) { CopyOnWrite()->SetItem(i, std::move(value)); }
/*! \return The underlying ArrayObj */
ArrayObj* GetArrayObj() const { return static_cast<ArrayObj*>(data_.get()); }
@@ -823,7 +648,7 @@ class Array : public ObjectRef {
TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative
size";
}
ArrayObj* p = GetArrayObj();
- if (p != nullptr && data_.unique() && p->capacity_ >= cap) {
+ if (p != nullptr && data_.unique() && p->TVMFFISeqCell::capacity >= cap) {
// do not have to make new space
p->clear();
} else {
@@ -833,7 +658,7 @@ class Array : public ObjectRef {
}
// To ensure exception safety, size is only incremented after the
initialization succeeds
Any* itr = p->MutableBegin();
- for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) {
+ for (int64_t& i = p->TVMFFISeqCell::size = 0; i < cap; ++i, ++first,
++itr) {
new (itr) Any(*first);
}
}
@@ -885,11 +710,11 @@ class Array : public ObjectRef {
const int64_t kInitSize = ArrayObj::kInitSize;
return SwitchContainer(std::max(kInitSize, reserve_extra));
}
- if (p->capacity_ >= p->size_ + reserve_extra) {
+ if (p->TVMFFISeqCell::capacity >= p->TVMFFISeqCell::size + reserve_extra) {
return CopyOnWrite();
}
- int64_t cap = p->capacity_ * ArrayObj::kIncFactor;
- cap = std::max(cap, p->size_ + reserve_extra);
+ int64_t cap = p->TVMFFISeqCell::capacity * ArrayObj::kIncFactor;
+ cap = std::max(cap, p->TVMFFISeqCell::size + reserve_extra);
return SwitchContainer(cap);
}
@@ -1065,79 +890,16 @@ template <typename T>
inline constexpr bool use_default_type_traits_v<Array<T>> = false;
template <typename T>
-struct TypeTraits<Array<T>> : public ObjectRefTypeTraitsBase<Array<T>> {
+struct TypeTraits<Array<T>> : public SeqTypeTraitsBase<TypeTraits<Array<T>>,
Array<T>, T> {
static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray;
- using ObjectRefTypeTraitsBase<Array<T>>::CopyFromAnyViewAfterCheck;
-
- TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
- if (src->type_index != TypeIndex::kTVMFFIArray) {
- return TypeTraitsBase::GetMismatchTypeInfo(src);
- }
- if constexpr (!std::is_same_v<T, Any>) {
- const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
- for (size_t i = 0; i < n->size(); i++) {
- const Any& any_v = (*n)[static_cast<int64_t>(i)];
- // CheckAnyStrict is cheaper than try_cast<T>
- if (details::AnyUnsafe::CheckAnyStrict<T>(any_v)) continue;
- // try see if p is convertible to T
- if (any_v.try_cast<T>()) continue;
- // now report the accurate mismatch information
- return "Array[index " + std::to_string(i) + ": " +
- details::AnyUnsafe::GetMismatchTypeInfo<T>(any_v) + "]";
- }
- }
- TVM_FFI_THROW(InternalError) << "Cannot reach here";
- TVM_FFI_UNREACHABLE();
- }
-
- TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
- if (src->type_index != TypeIndex::kTVMFFIArray) return false;
- if constexpr (std::is_same_v<T, Any>) {
- return true;
- } else {
- const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
- for (const Any& any_v : *n) {
- if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
- }
- return true;
- }
- }
+ static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIArray;
+ static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIList;
+ static constexpr const char* kTypeName = "Array";
+ static constexpr const char* kStaticTypeKey = StaticTypeKey::kTVMFFIArray;
- TVM_FFI_INLINE static std::optional<Array<T>> TryCastFromAnyView(const
TVMFFIAny* src) {
- // try to run conversion.
- if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt;
- if constexpr (!std::is_same_v<T, Any>) {
- const ArrayObj* n = reinterpret_cast<const ArrayObj*>(src->v_obj);
- bool storage_check = [&]() {
- for (const Any& any_v : *n) {
- if (!details::AnyUnsafe::CheckAnyStrict<T>(any_v)) return false;
- }
- return true;
- }();
- // fast path, if storage check passes, we can return the array directly.
- if (storage_check) {
- return CopyFromAnyViewAfterCheck(src);
- }
- // slow path, try to run a conversion to Array<T>
- Array<T> result;
- result.reserve(n->size());
- for (const Any& any_v : *n) {
- if (auto opt_v = any_v.try_cast<T>()) {
- result.push_back(*std::move(opt_v));
- } else {
- return std::nullopt;
- }
- }
- return result;
- } else {
- return CopyFromAnyViewAfterCheck(src);
- }
- }
-
- TVM_FFI_INLINE static std::string TypeStr() { return "Array<" +
details::Type2Str<T>::v() + ">"; }
TVM_FFI_INLINE static std::string TypeSchema() {
std::ostringstream oss;
- oss << R"({"type":")" << StaticTypeKey::kTVMFFIArray << R"(","args":[)";
+ oss << R"({"type":")" << kStaticTypeKey << R"(","args":[)";
oss << details::TypeSchema<T>::v();
oss << "]}";
return oss.str();
@@ -1151,4 +913,5 @@ inline constexpr bool type_contains_v<Array<T>, Array<U>>
= type_contains_v<T, U
} // namespace ffi
} // namespace tvm
+
#endif // TVM_FFI_CONTAINER_ARRAY_H_
diff --git a/include/tvm/ffi/container/list.h b/include/tvm/ffi/container/list.h
new file mode 100644
index 0000000..6e52be6
--- /dev/null
+++ b/include/tvm/ffi/container/list.h
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/container/list.h
+ * \brief Mutable list type.
+ *
+ * tvm::ffi::List<Any> is an erased mutable sequence container.
+ */
+#ifndef TVM_FFI_CONTAINER_LIST_H_
+#define TVM_FFI_CONTAINER_LIST_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/seq_base.h>
+#include <tvm/ffi/object.h>
+
+#include <algorithm>
+#include <initializer_list>
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Object backing tvm::ffi::List: a growable, heap-buffered sequence of Any.
+ *
+ * Unlike ArrayObj, whose elements live inside the object allocation, ListObj
+ * keeps its elements in a separately allocated buffer (released through
+ * data_deleter). This lets Reserve() regrow the storage without replacing
+ * the object itself.
+ */
+class ListObj : public SeqBaseObj {
+ public:
+  /*!
+   * \brief Constructs a container with n elements. Each element is a copy of val
+   * \param n The size of the container
+   * \param val The init value
+   * \return Ref-counted ListObj requested
+   * \throws ValueError if n is negative or too large to allocate.
+   */
+  static ObjectPtr<ListObj> CreateRepeated(int64_t n, const Any& val) {
+    ObjectPtr<ListObj> p = ListObj::Empty(n);
+    Any* itr = p->MutableBegin();
+    // For exception safety, size (aliased by i) is incremented only after
+    // each element has been constructed.
+    for (int64_t& i = p->TVMFFISeqCell::size = 0; i < n; ++i) {
+      new (itr++) Any(val);
+    }
+    return p;
+  }
+
+  /// \cond Doxygen_Suppress
+  static constexpr const int32_t _type_index = TypeIndex::kTVMFFIList;
+  static const constexpr bool _type_final = true;
+  TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIList, ListObj, Object);
+  /// \endcond
+
+ private:
+  /*!
+   * \brief Allocate uninitialized storage for n Any slots.
+   * \param n Number of slots; must be non-negative.
+   * \return The new buffer, or nullptr when n == 0.
+   * \throws ValueError if n is negative or sizeof(Any) * n would overflow size_t.
+   */
+  static Any* AllocateBuffer(int64_t n) {
+    if (n < 0) {
+      TVM_FFI_THROW(ValueError) << "cannot construct a List of negative size";
+    }
+    if (n == 0) {
+      return nullptr;
+    }
+    // Guard the byte-count multiplication. Without this check, a huge n wraps
+    // around, operator new returns a too-small buffer and callers write past it.
+    constexpr uint64_t kMaxElems =
+        static_cast<uint64_t>(static_cast<size_t>(-1) / sizeof(Any));
+    if (static_cast<uint64_t>(n) > kMaxElems) {
+      TVM_FFI_THROW(ValueError) << "cannot allocate a List of " << n << " elements";
+    }
+    return static_cast<Any*>(::operator new(sizeof(Any) * static_cast<size_t>(n)));
+  }
+
+  /*!
+   * \brief Ensure the list has at least n slots.
+   * \param n The lower bound of required capacity.
+   *
+   * Existing elements are moved into a freshly allocated buffer, and the old
+   * buffer is released through data_deleter. Does nothing if n <= capacity.
+   *
+   * \note Leak-safety: the only throwing step is the allocation, which happens
+   *       before any element is touched. Any's move constructor is noexcept, so
+   *       the move loop cannot leave the buffer in a partially-constructed state.
+   */
+  void Reserve(int64_t n) {
+    if (n <= TVMFFISeqCell::capacity) {
+      return;
+    }
+    Any* new_data = AllocateBuffer(n);
+    Any* old_data = MutableBegin();
+    const int64_t count = TVMFFISeqCell::size;
+    for (int64_t i = 0; i < count; ++i) {
+      new (new_data + i) Any(std::move(old_data[i]));
+    }
+    for (int64_t i = 0; i < count; ++i) {
+      (old_data + i)->Any::~Any();
+    }
+    data_deleter(data);
+    data = new_data;
+    TVMFFISeqCell::capacity = n;
+  }
+
+  /*!
+   * \brief Create an empty ListObj with the given capacity.
+   * \param n Required capacity
+   * \return Ref-counted ListObj requested
+   * \throws ValueError if n is negative or too large to allocate.
+   */
+  static ObjectPtr<ListObj> Empty(int64_t n = kInitSize) {
+    if (n < 0) {
+      TVM_FFI_THROW(ValueError) << "cannot construct a List of negative size";
+    }
+    ObjectPtr<ListObj> p = make_object<ListObj>();
+    // Establish a valid empty state before allocating. If the allocation
+    // throws, releasing p then never sees an uninitialized data/data_deleter
+    // pair (RawDataDeleter(nullptr) is a no-op).
+    p->TVMFFISeqCell::size = 0;
+    p->TVMFFISeqCell::capacity = 0;
+    p->data = nullptr;
+    p->data_deleter = RawDataDeleter;
+    p->data = AllocateBuffer(n);
+    p->TVMFFISeqCell::capacity = n;
+    return p;
+  }
+
+  /*! \brief Deleter for buffers obtained from AllocateBuffer; accepts nullptr. */
+  static void RawDataDeleter(void* data) { ::operator delete(data); }
+
+  /*! \brief Initial size of ListObj */
+  static constexpr int64_t kInitSize = 4;
+  /*! \brief Expansion factor of the List */
+  static constexpr int64_t kIncFactor = 2;
+
+  template <typename, typename>
+  friend class List;
+
+  template <typename, typename>
+  friend struct TypeTraits;
+};
+
+/*!
+ * \brief List, container representing a mutable contiguous sequence of
ObjectRefs.
+ *
+ * Unlike Array, List is mutable and does not implement copy-on-write.
+ * Mutations happen directly on the underlying shared ListObj.
+ *
+ * \note Thread safety: List is **not** thread-safe. Concurrent reads and
writes
+ * from multiple threads require external synchronization.
+ *
+ * \warning Because List elements are stored as `Any` (which may hold
`ObjectRef`
+ * pointers), it is possible to create reference cycles (e.g. a List
that
+ * contains itself). Such cycles will **not** be collected by the
reference-
+ * counting mechanism alone; avoid them in long-lived data structures.
+ *
+ * \tparam T The content Value type, must be compatible with tvm::ffi::Any
+ */
+template <typename T, typename = typename
std::enable_if_t<details::storage_enabled_v<T>>>
+class List : public ObjectRef {
+ public:
+ /*! \brief The value type of the list */
+ using value_type = T;
+
+ /*! \brief Construct a List with UnsafeInit */
+ explicit List(UnsafeInit tag) : ObjectRef(tag) {}
+ /*! \brief default constructor */
+ List() { data_ = ListObj::Empty(0); } //
NOLINT(modernize-use-equals-default)
+ /*! \brief Move constructor */
+ List(List<T>&& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(std::move(other.data_)) {}
+ /*! \brief Copy constructor */
+ List(const List<T>& other) : ObjectRef(other.data_) {} //
NOLINT(google-explicit-constructor)
+
+ /*!
+ * \brief Constructor from another list
+ * \tparam U The value type of the other list
+ */
+ template <typename U, typename =
std::enable_if_t<details::type_contains_v<T, U>>>
+ List(List<U>&& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(std::move(other.data_)) {}
+
+ /*!
+ * \brief Constructor from another list
+ * \tparam U The value type of the other list
+ */
+ template <typename U, typename =
std::enable_if_t<details::type_contains_v<T, U>>>
+ List(const List<U>& other) // NOLINT(google-explicit-constructor)
+ : ObjectRef(other.data_) {}
+
+ /*!
+ * \brief Move assignment from another list.
+ * \param other The other list.
+ */
+ TVM_FFI_INLINE List<T>& operator=(List<T>&& other) {
+ data_ = std::move(other.data_);
+ return *this;
+ }
+
+ /*!
+ * \brief Assignment from another list.
+ * \param other The other list.
+ */
+ TVM_FFI_INLINE List<T>& operator=(const List<T>& other) {
+ data_ = other.data_;
+ return *this;
+ }
+
+ /*!
+ * \brief Move assignment from another list.
+ * \param other The other list.
+ * \tparam U The value type of the other list.
+ */
+ template <typename U, typename =
std::enable_if_t<details::type_contains_v<T, U>>>
+ TVM_FFI_INLINE List<T>& operator=(List<U>&& other) {
+ data_ = std::move(other.data_);
+ return *this;
+ }
+
+ /*!
+ * \brief Assignment from another list.
+ * \param other The other list.
+ * \tparam U The value type of the other list.
+ */
+ template <typename U, typename =
std::enable_if_t<details::type_contains_v<T, U>>>
+ TVM_FFI_INLINE List<T>& operator=(const List<U>& other) {
+ data_ = other.data_;
+ return *this;
+ }
+
+  /*!
+   * \brief Wrap an existing object pointer as a List handle.
+   * \param n The object to hold; it is taken as-is.
+   */
+  explicit List(ObjectPtr<Object> n) : ObjectRef(std::move(n)) {}
+
+  /*!
+   * \brief Build a list from the half-open range [first, last).
+   * \param first Iterator to the first element to copy.
+   * \param last Iterator one past the last element to copy.
+   * \tparam IterType Iterator whose value type can be stored in a List<T>.
+   */
+  template <typename IterType>
+  List(IterType first, IterType last) {  // NOLINT(performance-unnecessary-value-param)
+    static_assert(is_valid_iterator_v<T, IterType>,
+                  "IterType cannot be inserted into a tvm::List<T>");
+    this->Assign(first, last);
+  }
+
+  /*!
+   * \brief Build a list holding copies of the values in a braced initializer list.
+   * \param init The values to copy.
+   */
+  List(std::initializer_list<T> init) {  // NOLINT(*)
+    this->Assign(std::begin(init), std::end(init));
+  }
+
+  /*!
+   * \brief Build a list holding copies of the values in a std::vector.
+   * \param init The vector to copy from.
+   */
+  List(const std::vector<T>& init) {  // NOLINT(*)
+    this->Assign(std::cbegin(init), std::cend(init));
+  }
+
+  /*!
+   * \brief Build a list of n copies of val.
+   * \param n Number of elements.
+   * \param val Value replicated into every slot.
+   */
+  explicit List(const size_t n, const T& val) {
+    this->data_ = ListObj::CreateRepeated(n, val);
+  }
+
+ public:
+  // iterators
+  /// \cond Doxygen_Suppress
+  struct ValueConverter {
+    using ResultType = T;
+    static T convert(const Any& n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(n); }
+  };
+  /// \endcond
+
+  /*! \brief The iterator type of the list */
+  using iterator = details::IterAdapter<ValueConverter, const Any*>;
+  /*! \brief The reverse iterator type of the list */
+  using reverse_iterator = details::ReverseIterAdapter<ValueConverter, const Any*>;
+
+  /*!
+   * \return begin iterator.
+   * \note A null handle (e.g. a moved-from List) is iterated as an empty list,
+   *  consistent with size()/empty(); begin() == end() in that case.
+   */
+  iterator begin() const {
+    ListObj* p = GetListObj();
+    return iterator(p == nullptr ? nullptr : p->begin());
+  }
+
+  /*! \return end iterator (equal to begin() for a null handle). */
+  iterator end() const {
+    ListObj* p = GetListObj();
+    return iterator(p == nullptr ? nullptr : p->end());
+  }
+
+  /*! \return rbegin iterator (equal to rend() for a null handle). */
+  reverse_iterator rbegin() const {
+    ListObj* p = GetListObj();
+    // A null handle has no buffer to step back from; use one sentinel for both ends.
+    return reverse_iterator(p == nullptr ? nullptr : p->end() - 1);
+  }
+
+  /*! \return rend iterator (equal to rbegin() for a null handle). */
+  reverse_iterator rend() const {
+    ListObj* p = GetListObj();
+    return reverse_iterator(p == nullptr ? nullptr : p->begin() - 1);
+  }
+
+ public:
+  // const methods in std::vector
+  /*!
+   * \brief Read the i-th element without modifying the list.
+   * \param i Zero-based position; out-of-range values are rejected by the container.
+   * \return A copy of the element at position i, converted to T.
+   * \throws IndexError if this handle is null.
+   */
+  T operator[](int64_t i) const {
+    ListObj* obj = GetListObj();
+    if (obj == nullptr) {
+      TVM_FFI_THROW(IndexError) << "cannot index a null list";
+    }
+    const Any& slot = obj->at(i);
+    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(slot);
+  }
+
+  /*! \return Number of elements; 0 for a null handle. */
+  size_t size() const {
+    if (ListObj* obj = GetListObj()) {
+      return obj->size();
+    }
+    return 0;
+  }
+
+  /*! \return Number of allocated slots; 0 for a null handle. */
+  size_t capacity() const {
+    if (ListObj* obj = GetListObj()) {
+      return obj->SeqBaseObj::capacity();
+    }
+    return 0;
+  }
+
+  /*! \return true when the list holds no elements (including a null handle). */
+  bool empty() const { return !size(); }
+
+  /*!
+   * \return A copy of the first element.
+   * \throws IndexError if this handle is null (or, via the container, if empty).
+   */
+  T front() const {
+    ListObj* obj = GetListObj();
+    if (obj == nullptr) {
+      TVM_FFI_THROW(IndexError) << "cannot index a null list";
+    }
+    const Any& slot = obj->front();
+    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(slot);
+  }
+
+  /*!
+   * \return A copy of the last element.
+   * \throws IndexError if this handle is null (or, via the container, if empty).
+   */
+  T back() const {
+    ListObj* obj = GetListObj();
+    if (obj == nullptr) {
+      TVM_FFI_THROW(IndexError) << "cannot index a null list";
+    }
+    const Any& slot = obj->back();
+    return details::AnyUnsafe::CopyFromAnyViewAfterCheck<T>(slot);
+  }
+
+ public:
+  // mutation in std::vector
+  /*!
+   * \brief Append a copy of item to the end of the list.
+   * \param item The item to append.
+   * \note The new slot is constructed before the size is incremented, so if
+   *  constructing the Any throws, the size is unchanged and no uninitialized
+   *  slot is ever treated as a live element.
+   */
+  void push_back(const T& item) {
+    ListObj* p = EnsureCapacity(1);
+    p->EmplaceInit(p->TVMFFISeqCell::size, item);
+    ++p->TVMFFISeqCell::size;  // only count the slot once it is constructed
+  }
+
+  /*!
+   * \brief Construct a new element in place at the end of the list.
+   * \param args Arguments forwarded to the Any constructor.
+   * \note Same ordering as push_back: construct first, then bump the size.
+   */
+  template <typename... Args>
+  void emplace_back(Args&&... args) {
+    ListObj* p = EnsureCapacity(1);
+    p->EmplaceInit(p->TVMFFISeqCell::size, std::forward<Args>(args)...);
+    ++p->TVMFFISeqCell::size;  // only count the slot once it is constructed
+  }
+
+  /*!
+   * \brief Insert a copy of val before position.
+   * \param position Iterator into this list marking the insertion point.
+   * \param val The element to insert.
+   * \throws RuntimeError if this handle is null.
+   */
+  void insert(iterator position, const T& val) {
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "cannot insert to a null list";
+    }
+    const int64_t offset = std::distance(begin(), position);
+    ListObj* p = EnsureCapacity(1);
+    p->insert(offset, Any(val));
+  }
+
+  /*!
+   * \brief Insert copies of [first, last) before position.
+   * \param position Iterator into this list marking the insertion point.
+   * \param first Begin of the source range.
+   * \param last End of the source range.
+   * \tparam IterType Iterator whose value type can be stored in a List<T>.
+   * \throws RuntimeError if the range is non-empty and this handle is null.
+   */
+  template <typename IterType>
+  void insert(iterator position, IterType first, IterType last) {
+    static_assert(is_valid_iterator_v<T, IterType>,
+                  "IterType cannot be inserted into a tvm::List<T>");
+    if (first == last) {
+      return;
+    }
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "cannot insert to a null list";
+    }
+    const int64_t offset = std::distance(begin(), position);
+    const int64_t extra = std::distance(first, last);
+    ListObj* p = EnsureCapacity(extra);
+    p->insert(offset, first, last);
+  }
+
+  /*!
+   * \brief Drop the last element.
+   * \throws RuntimeError if this handle is null.
+   */
+  void pop_back() {
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null list";
+    }
+    ListObj* p = GetListObj();
+    p->pop_back();
+  }
+
+  /*!
+   * \brief Remove the element at position; later elements shift left.
+   * \param position Iterator to the element to remove.
+   * \throws RuntimeError if this handle is null.
+   */
+  void erase(iterator position) {
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "cannot erase a null list";
+    }
+    const int64_t offset = std::distance(begin(), position);
+    ListObj* p = GetListObj();
+    p->erase(offset);
+  }
+
+  /*!
+   * \brief Remove the elements in [first, last).
+   * \param first Iterator to the first element to remove.
+   * \param last Iterator one past the last element to remove.
+   * \throws RuntimeError if the range is non-empty and this handle is null.
+   */
+  void erase(iterator first, iterator last) {
+    if (first == last) {
+      return;
+    }
+    if (data_ == nullptr) {
+      TVM_FFI_THROW(RuntimeError) << "cannot erase a null list";
+    }
+    const iterator origin = begin();
+    const int64_t begin_idx = std::distance(origin, first);
+    const int64_t end_idx = std::distance(origin, last);
+    ListObj* p = GetListObj();
+    p->erase(begin_idx, end_idx);
+  }
+
+  /*!
+   * \brief Grow or shrink the list to exactly n elements (creating the container if needed).
+   * \param n The requested size.
+   * \throws ValueError if n is negative.
+   */
+  void resize(int64_t n) {
+    if (n < 0) {
+      TVM_FFI_THROW(ValueError) << "cannot resize a List to negative size";
+    }
+    const int64_t current = static_cast<int64_t>(size());
+    const int64_t growth = (n > current) ? (n - current) : 0;
+    EnsureCapacity(growth)->resize(n);
+  }
+
+  /*!
+   * \brief Ask the container for room for at least n elements (creating it if needed).
+   * \param n Lower bound on the capacity.
+   */
+  void reserve(int64_t n) {
+    ListObj* p = EnsureListObj();
+    p->Reserve(n);
+  }
+
+  /*! \brief Drop every element; a no-op on a null handle. */
+  void clear() {
+    if (data_ == nullptr) {
+      return;
+    }
+    GetListObj()->clear();
+  }
+
+ public:
+  // List's own methods
+  /*!
+   * \brief Overwrite the element at position i (creating the container if needed).
+   * \param i Zero-based position; out-of-range values are rejected by the container.
+   * \param value The new value.
+   */
+  void Set(int64_t i, T value) {
+    ListObj* obj = EnsureListObj();
+    obj->SetItem(i, std::move(value));
+  }
+
+  /*! \return Raw pointer to the underlying ListObj, or nullptr for a null handle. */
+  ListObj* GetListObj() const { return static_cast<ListObj*>(data_.get()); }
+
+  /*!
+   * \brief Replace the contents, in place, with copies of [first, last).
+   * \param first Begin of the source range.
+   * \param last End of the source range.
+   * \tparam IterType Iterator whose value type can be stored in an Any.
+   * \note The existing ListObj (if any) is reused, so other handles sharing it see the change.
+   */
+  template <typename IterType>
+  void Assign(IterType first, IterType last) {  // NOLINT(performance-unnecessary-value-param)
+    const int64_t count = std::distance(first, last);
+    if (count < 0) {
+      TVM_FFI_THROW(ValueError) << "cannot construct a List of negative size";
+    }
+    ListObj* obj = EnsureListObj();
+    obj->Reserve(count);
+    obj->clear();
+    Any* slot = obj->MutableBegin();
+    obj->TVMFFISeqCell::size = 0;
+    while (obj->TVMFFISeqCell::size < count) {
+      new (slot) Any(*first);
+      // Count the slot only after it has been constructed.
+      ++obj->TVMFFISeqCell::size;
+      ++first;
+      ++slot;
+    }
+  }
+
+  /*! \brief specify container node */
+  using ContainerType = ListObj;
+
+ private:
+  /*!
+   * \brief Make sure a container exists with room for reserve_extra more elements.
+   * \param reserve_extra Number of slots needed beyond the current size.
+   * \return The container, after calling Reserve if it was too small.
+   * \note Growth is geometric (ListObj::kIncFactor) unless the request needs more.
+   */
+  ListObj* EnsureCapacity(int64_t reserve_extra) {
+    ListObj* obj = EnsureListObj();
+    const int64_t required = obj->TVMFFISeqCell::size + reserve_extra;
+    if (obj->TVMFFISeqCell::capacity < required) {
+      const int64_t grown = obj->TVMFFISeqCell::capacity * ListObj::kIncFactor;
+      obj->Reserve(std::max<int64_t>(grown, required));
+    }
+    return obj;
+  }
+
+  /*!
+   * \brief Lazily create an empty container for a null handle.
+   * \return The (now non-null) container.
+   */
+  ListObj* EnsureListObj() {
+    if (data_ == nullptr) {
+      data_ = ListObj::Empty();
+    }
+    return GetListObj();
+  }
+
+  template <typename, typename>
+  friend class List;
+};
+
+// Traits for List
+template <typename T>
+inline constexpr bool use_default_type_traits_v<List<T>> = false;
+
+/*!
+ * \brief Any-conversion traits for List<T>.
+ *
+ * A kTVMFFIList value whose elements are already stored as T is reused without
+ * building a new container. A kTVMFFIArray value, or a list needing element
+ * conversion, is copied element-by-element into a fresh List (see SeqTypeTraitsBase).
+ */
+template <typename T>
+struct TypeTraits<List<T>> : public SeqTypeTraitsBase<TypeTraits<List<T>>, List<T>, T> {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIList;
+  static constexpr int32_t kPrimaryTypeIndex = TypeIndex::kTVMFFIList;
+  static constexpr int32_t kOtherTypeIndex = TypeIndex::kTVMFFIArray;
+  static constexpr const char* kTypeName = "List";
+  static constexpr const char* kStaticTypeKey = StaticTypeKey::kTVMFFIList;
+
+  /*! \return JSON schema: {"type":<kStaticTypeKey>,"args":[<element schema>]}. */
+  TVM_FFI_INLINE static std::string TypeSchema() {
+    std::ostringstream schema;
+    schema << R"({"type":")" << kStaticTypeKey << R"(","args":[)" << details::TypeSchema<T>::v()
+           << "]}";
+    return schema.str();
+  }
+};
+
+namespace details {
+/*! \brief List<T> accepts List<U> exactly when T accepts U. */
+template <typename T, typename U>
+inline constexpr bool type_contains_v<List<T>, List<U>> = type_contains_v<T, U>;
+}  // namespace details
+
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_CONTAINER_LIST_H_
diff --git a/include/tvm/ffi/container/seq_base.h
b/include/tvm/ffi/container/seq_base.h
new file mode 100644
index 0000000..4d5a361
--- /dev/null
+++ b/include/tvm/ffi/container/seq_base.h
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/container/seq_base.h
+ * \brief Base class for sequence containers (Array, List).
+ */
+#ifndef TVM_FFI_CONTAINER_SEQ_BASE_H_
+#define TVM_FFI_CONTAINER_SEQ_BASE_H_
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/object.h>
+
+#include <algorithm>
+#include <cstddef>
+#include <iterator>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+/*!
+ * \brief Base class for sequence containers (ArrayObj, ListObj).
+ *
+ * SeqBaseObj is transparent to the FFI type system (no type index),
+ * following the same pattern as BytesObjBase.
+ *
+ * Storage is held in the inherited TVMFFISeqCell fields: `data` points to a
+ * buffer of `capacity` Any slots, of which the first `size` hold constructed
+ * elements. On destruction those elements are destroyed and, if
+ * `data_deleter` is set, the buffer is released through it.
+ *
+ * Mutators that grow the sequence (insert, resize) do not allocate; callers
+ * must reserve enough capacity beforehand.
+ */
+class SeqBaseObj : public Object, protected TVMFFISeqCell {
+ public:
+ SeqBaseObj() { // empty: no buffer, zero size/capacity, no deleter
+ data = nullptr;
+ TVMFFISeqCell::size = 0;
+ TVMFFISeqCell::capacity = 0;
+ data_deleter = nullptr;
+ }
+
+ ~SeqBaseObj() { // destroy the live elements, then hand the buffer to data_deleter
+ Any* begin = MutableBegin();
+ for (int64_t i = 0; i < TVMFFISeqCell::size; ++i) {
+ (begin + i)->Any::~Any();
+ }
+ if (data_deleter != nullptr) {
+ data_deleter(data);
+ }
+ }
+
+ /*! \return The size of the sequence */
+ size_t size() const { return static_cast<size_t>(TVMFFISeqCell::size); }
+
+ /*! \return The capacity of the sequence */
+ size_t capacity() const { return
static_cast<size_t>(TVMFFISeqCell::capacity); }
+
+ /*! \return Whether the sequence is empty */
+ bool empty() const { return TVMFFISeqCell::size == 0; }
+
+ /*!
+ * \brief Read i-th element from the sequence.
+ * \param i The index
+ * \return the i-th element.
+ * \throws IndexError if i is not in [0, size()).
+ */
+ const Any& at(int64_t i) const { return this->operator[](i); }
+
+ /*!
+ * \brief Read i-th element from the sequence.
+ * \param i The index
+ * \return the i-th element.
+ * \throws IndexError if i is not in [0, size()).
+ */
+ const Any& operator[](int64_t i) const {
+ if (i < 0 || i >= TVMFFISeqCell::size) {
+ TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " <<
TVMFFISeqCell::size;
+ }
+ return static_cast<Any*>(data)[i];
+ }
+
+ /*!
+ * \return The first element
+ * \throws IndexError if the sequence is empty.
+ */
+ const Any& front() const {
+ if (TVMFFISeqCell::size == 0) {
+ TVM_FFI_THROW(IndexError) << "front() on empty sequence";
+ }
+ return static_cast<Any*>(data)[0];
+ }
+
+ /*!
+ * \return The last element
+ * \throws IndexError if the sequence is empty.
+ */
+ const Any& back() const {
+ if (TVMFFISeqCell::size == 0) {
+ TVM_FFI_THROW(IndexError) << "back() on empty sequence";
+ }
+ return static_cast<Any*>(data)[TVMFFISeqCell::size - 1];
+ }
+
+ /*! \return begin constant iterator (null when no buffer is attached) */
+ const Any* begin() const { return static_cast<Any*>(data); }
+
+ /*! \return end constant iterator */
+ const Any* end() const { return begin() + TVMFFISeqCell::size; }
+
+ /*!
+ * \brief Release reference to all the elements
+ * \note Elements are destroyed back-to-front; the buffer and capacity are kept.
+ */
+ void clear() {
+ Any* itr = MutableEnd();
+ while (TVMFFISeqCell::size > 0) {
+ (--itr)->Any::~Any();
+ --TVMFFISeqCell::size;
+ }
+ }
+
+ /*!
+ * \brief Set i-th element of the sequence in-place
+ * \param i The index
+ * \param item The value to be set
+ * \throws IndexError if i is not in [0, size()).
+ */
+ void SetItem(int64_t i, Any item) {
+ if (i < 0 || i >= TVMFFISeqCell::size) {
+ TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " <<
TVMFFISeqCell::size;
+ }
+ static_cast<Any*>(data)[i] = std::move(item);
+ }
+
+ /*!
+ * \brief Remove the last element
+ * \throws IndexError if the sequence is empty.
+ */
+ void pop_back() {
+ if (TVMFFISeqCell::size == 0) {
+ TVM_FFI_THROW(IndexError) << "pop_back on empty sequence";
+ }
+ ShrinkBy(1);
+ }
+
+ /*!
+ * \brief Erase element at position idx
+ * \param idx The index to erase
+ * \throws IndexError if idx is not in [0, size()).
+ * \note Elements after idx are move-assigned one slot left, then the now
+ * duplicated tail slot is destroyed.
+ */
+ void erase(int64_t idx) {
+ if (idx < 0 || idx >= TVMFFISeqCell::size) {
+ TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " <<
TVMFFISeqCell::size;
+ }
+ MoveElementsLeft(idx, idx + 1, TVMFFISeqCell::size);
+ ShrinkBy(1);
+ }
+
+ /*!
+ * \brief Erase elements in half-open range [first, last)
+ * \param first Start index (inclusive)
+ * \param last End index (exclusive)
+ * \throws IndexError if first < 0, last > size() or first > last.
+ * \note An empty range (first == last) returns before any bounds check.
+ */
+ void erase(int64_t first, int64_t last) {
+ if (first == last) return;
+ if (first < 0 || last > TVMFFISeqCell::size || first >= last) {
+ TVM_FFI_THROW(IndexError) << "Erase range [" << first << ", " << last <<
") out of bounds "
+ << TVMFFISeqCell::size;
+ }
+ MoveElementsLeft(first, last, TVMFFISeqCell::size);
+ ShrinkBy(last - first);
+ }
+
+ /*!
+ * \brief Insert element at position idx
+ * \param idx The index to insert at
+ * \param item The value to insert
+ * \note Caller must ensure capacity >= size + 1
+ * \throws IndexError if idx is not in [0, size()].
+ * \note A default Any is first appended, elements from idx are shifted one
+ * slot right, and then item is move-assigned into slot idx.
+ */
+ void insert(int64_t idx, Any item) {
+ int64_t sz = TVMFFISeqCell::size;
+ if (idx < 0 || idx > sz) {
+ TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " <<
sz << "]";
+ }
+ EnlargeBy(1);
+ MoveElementsRight(idx + 1, idx, sz);
+ MutableBegin()[idx] = std::move(item);
+ }
+
+ /*!
+ * \brief Insert elements from iterator range at position idx
+ * \param idx The index to insert at
+ * \param first Begin of iterator
+ * \param last End of iterator
+ * \tparam IterType The type of iterator
+ * \note Caller must ensure capacity >= size + distance(first, last)
+ * \throws IndexError if the range is non-empty and idx is not in [0, size()].
+ * \note NOTE(review): std::distance(first, last) is evaluated before the
+ * elements are copied, so a single-pass input iterator would be consumed;
+ * presumably forward (multi-pass) iterators are expected.
+ */
+ template <typename IterType>
+ void insert(int64_t idx, IterType first, IterType last) {
+ int64_t count = std::distance(first, last);
+ if (count == 0) return;
+ int64_t sz = TVMFFISeqCell::size;
+ if (idx < 0 || idx > sz) {
+ TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds [0, " <<
sz << "]";
+ }
+ EnlargeBy(count);
+ MoveElementsRight(idx + count, idx, sz);
+ Any* dst = MutableBegin() + idx;
+ for (; first != last; ++first, ++dst) {
+ *dst = Any(*first);
+ }
+ }
+
+ /*! \brief Reverse the elements in-place */
+ void Reverse() { std::reverse(MutableBegin(), MutableBegin() +
TVMFFISeqCell::size); }
+
+ /*!
+ * \brief Resize the sequence
+ * \param n The new size
+ * \note Caller must ensure capacity >= n when growing
+ * \note Growing appends default-constructed Any values; shrinking destroys
+ * the trailing elements.
+ * \throws ValueError if n is negative.
+ */
+ void resize(int64_t n) {
+ if (n < 0) {
+ TVM_FFI_THROW(ValueError) << "Cannot resize to negative size";
+ }
+ int64_t old_size = TVMFFISeqCell::size;
+ if (old_size < n) {
+ EnlargeBy(n - old_size);
+ } else if (old_size > n) {
+ ShrinkBy(old_size - n);
+ }
+ }
+
+ protected:
+ /// \cond Doxygen_Suppress
+ Any* MutableBegin() const { return static_cast<Any*>(this->data); } // mutable view of the element buffer
+
+ Any* MutableEnd() const { return MutableBegin() + TVMFFISeqCell::size; } // one past the last live element
+
+ template <typename... Args>
+ void EmplaceInit(size_t idx, Args&&... args) { // placement-new into slot idx; does not touch size
+ Any* itr = MutableBegin() + idx;
+ new (itr) Any(std::forward<Args>(args)...);
+ }
+
+ void EnlargeBy(int64_t delta, const Any& val = Any()) { // no capacity check; size bumped after each construction
+ Any* itr = MutableEnd();
+ while (delta-- > 0) {
+ new (itr++) Any(val);
+ ++TVMFFISeqCell::size;
+ }
+ }
+
+ void ShrinkBy(int64_t delta) { // destroy the last delta elements, back-to-front
+ Any* itr = MutableEnd();
+ while (delta-- > 0) {
+ (--itr)->Any::~Any();
+ --TVMFFISeqCell::size;
+ }
+ }
+
+ void MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { // move-assign [src_begin, src_end) to start at dst (expects dst <= src_begin)
+ Any* begin = MutableBegin();
+ std::move(begin + src_begin, begin + src_end, begin + dst);
+ }
+
+ void MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { // move-assign [src_begin, src_end) to start at dst (expects dst >= src_begin)
+ Any* begin = MutableBegin();
+ std::move_backward(begin + src_begin, begin + src_end, begin + dst +
(src_end - src_begin));
+ }
+ /// \endcond
+};
+
+/*!
+ * \brief CRTP helper implementing the Any-conversion traits shared by Array and List.
+ *
+ * \tparam Derived The concrete traits class, which must declare:
+ *   - `static constexpr int32_t kPrimaryTypeIndex`: values with this type index whose
+ *     elements already satisfy T are reused without building a new container;
+ *   - `static constexpr int32_t kOtherTypeIndex`: a second accepted sequence type index
+ *     whose values are always converted element-by-element into a new SeqRef;
+ *   - `static constexpr const char* kTypeName`: name used by TypeStr() and mismatch messages.
+ * \tparam SeqRef The reference type being produced (e.g. Array<T>, List<T>).
+ * \tparam T The element type.
+ */
+template <typename Derived, typename SeqRef, typename T>
+struct SeqTypeTraitsBase : public ObjectRefTypeTraitsBase<SeqRef> {
+  using Base = ObjectRefTypeTraitsBase<SeqRef>;
+  using Base::CopyFromAnyViewAfterCheck;
+
+  /*!
+   * \brief Strict check: src has the primary type index and, unless T is Any,
+   *  every element is already stored exactly as a T.
+   */
+  TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
+    if (src->type_index != Derived::kPrimaryTypeIndex) {
+      return false;
+    }
+    if constexpr (std::is_same_v<T, Any>) {
+      return true;
+    } else {
+      const SeqBaseObj* seq = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
+      return std::all_of(seq->begin(), seq->end(), [](const Any& elem) {
+        return details::AnyUnsafe::CheckAnyStrict<T>(elem);
+      });
+    }
+  }
+
+  /*!
+   * \brief Explain why src cannot be converted, naming the first element that neither
+   *  matches T strictly nor converts to T.
+   */
+  TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
+    const int32_t index = src->type_index;
+    if (index != Derived::kPrimaryTypeIndex && index != Derived::kOtherTypeIndex) {
+      return TypeTraitsBase::GetMismatchTypeInfo(src);
+    }
+    if constexpr (!std::is_same_v<T, Any>) {
+      const SeqBaseObj* seq = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
+      const size_t count = seq->size();
+      for (size_t pos = 0; pos < count; ++pos) {
+        const Any& elem = seq->at(static_cast<int64_t>(pos));
+        const bool acceptable = details::AnyUnsafe::CheckAnyStrict<T>(elem) ||
+                                static_cast<bool>(elem.try_cast<T>());
+        if (!acceptable) {
+          return std::string(Derived::kTypeName) + "[index " + std::to_string(pos) + ": " +
+                 details::AnyUnsafe::GetMismatchTypeInfo<T>(elem) + "]";
+        }
+      }
+    }
+    TVM_FFI_THROW(InternalError) << "Cannot reach here";
+    TVM_FFI_UNREACHABLE();
+  }
+
+  /*!
+   * \brief Try to obtain a SeqRef from src.
+   * \return The original container when src has the primary type index and needs no
+   *  element conversion; otherwise a freshly built SeqRef; std::nullopt when src is not
+   *  an accepted sequence or some element cannot be converted to T.
+   */
+  TVM_FFI_INLINE static std::optional<SeqRef> TryCastFromAnyView(const TVMFFIAny* src) {
+    const int32_t index = src->type_index;
+    const bool is_primary = (index == Derived::kPrimaryTypeIndex);
+    if (!is_primary && index != Derived::kOtherTypeIndex) {
+      return std::nullopt;
+    }
+    const SeqBaseObj* seq = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
+    if constexpr (std::is_same_v<T, Any>) {
+      if (is_primary) {
+        return CopyFromAnyViewAfterCheck(src);
+      }
+      SeqRef converted;
+      converted.reserve(static_cast<int64_t>(seq->size()));
+      for (const Any& elem : *seq) {
+        converted.push_back(elem);
+      }
+      return converted;
+    } else {
+      // Only a primary-typed container can be reused; the element scan is pure,
+      // so skipping it for the other type index does not change the result.
+      const bool already_typed =
+          is_primary && std::all_of(seq->begin(), seq->end(), [](const Any& elem) {
+            return details::AnyUnsafe::CheckAnyStrict<T>(elem);
+          });
+      if (already_typed) {
+        return CopyFromAnyViewAfterCheck(src);
+      }
+      SeqRef converted;
+      converted.reserve(static_cast<int64_t>(seq->size()));
+      for (const Any& elem : *seq) {
+        auto maybe_elem = elem.try_cast<T>();
+        if (!maybe_elem) {
+          return std::nullopt;
+        }
+        converted.push_back(*std::move(maybe_elem));
+      }
+      return converted;
+    }
+  }
+
+  /*! \return kTypeName followed by the element type name in angle brackets. */
+  TVM_FFI_INLINE static std::string TypeStr() {
+    std::string name(Derived::kTypeName);
+    name += "<";
+    name += details::Type2Str<T>::v();
+    name += ">";
+    return name;
+  }
+
+ private:
+  SeqTypeTraitsBase() = default;
+  friend Derived;
+};
+
+} // namespace ffi
+} // namespace tvm
+
+#endif // TVM_FFI_CONTAINER_SEQ_BASE_H_
diff --git a/include/tvm/ffi/container/tuple.h
b/include/tvm/ffi/container/tuple.h
index e5eb3ca..79e402e 100644
--- a/include/tvm/ffi/container/tuple.h
+++ b/include/tvm/ffi/container/tuple.h
@@ -198,7 +198,7 @@ class Tuple : public ObjectRef {
ObjectPtr<ArrayObj> p = ArrayObj::Empty(sizeof...(Types));
Any* itr = p->MutableBegin();
// increase size after each new to ensure exception safety
- ((new (itr++) Any(Types()), p->size_++), ...);
+ ((new (itr++) Any(Types()), p->TVMFFISeqCell::size++), ...);
return p;
}
@@ -207,7 +207,7 @@ class Tuple : public ObjectRef {
ObjectPtr<ArrayObj> p = ArrayObj::Empty(sizeof...(Types));
Any* itr = p->MutableBegin();
// increase size after each new to ensure exception safety
- ((new (itr++) Any(Types(std::forward<UTypes>(args))), p->size_++), ...);
+ ((new (itr++) Any(Types(std::forward<UTypes>(args))),
p->TVMFFISeqCell::size++), ...);
return p;
}
@@ -220,7 +220,7 @@ class Tuple : public ObjectRef {
// increase size after each new to ensure exception safety
for (size_t i = 0; i < sizeof...(Types); ++i) {
new (itr++) Any(*read++);
- p->size_++;
+ p->TVMFFISeqCell::size++;
}
data_ = std::move(p);
}
diff --git a/include/tvm/ffi/extra/stl.h b/include/tvm/ffi/extra/stl.h
index 462e699..8d1f26a 100644
--- a/include/tvm/ffi/extra/stl.h
+++ b/include/tvm/ffi/extra/stl.h
@@ -36,7 +36,9 @@
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/seq_base.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/object.h>
#include <tvm/ffi/type_traits.h>
@@ -124,7 +126,8 @@ struct TypeTraits<details::ListTemplate> : public
details::STLTypeTrait {
// increase size after each new to ensure exception safety
std::apply(
[&](auto&&... elems) {
- ((::new (dst++) Any(std::forward<decltype(elems)>(elems)),
array->size_++), ...);
+ ((::new (dst++) Any(std::forward<decltype(elems)>(elems)),
array->TVMFFISeqCell::size++),
+ ...);
},
std::forward<Tuple>(src));
return array;
@@ -137,7 +140,7 @@ struct TypeTraits<details::ListTemplate> : public
details::STLTypeTrait {
// increase size after each new to ensure exception safety
for (std::size_t i = 0; i < size; ++i) {
::new (dst++) Any(*(src++));
- array->size_++;
+ array->TVMFFISeqCell::size++;
}
return array;
}
@@ -206,7 +209,7 @@ struct TypeTraits<std::array<T, Nm>> : public
TypeTraits<details::ListTemplate>
TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
- return n.size_ == Nm;
+ return n.TVMFFISeqCell::size == Nm;
}
public:
@@ -251,7 +254,7 @@ struct TypeTraits<std::vector<T>> : public
TypeTraits<details::ListTemplate> {
using Self = std::vector<T>;
TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
- return src->type_index == TypeIndex::kTVMFFIArray;
+ return src->type_index == TypeIndex::kTVMFFIArray || src->type_index ==
TypeIndex::kTVMFFIList;
}
public:
@@ -266,13 +269,12 @@ struct TypeTraits<std::vector<T>> : public
TypeTraits<details::ListTemplate> {
TVM_FFI_INLINE static std::optional<Self> TryCastFromAnyView(const
TVMFFIAny* src) {
if (!CheckAnyFast(src)) return std::nullopt;
try {
- auto array = CopyFromAnyImpl<ArrayObj>(src);
- auto begin = array->MutableBegin();
+ const SeqBaseObj* seq = reinterpret_cast<const SeqBaseObj*>(src->v_obj);
auto result = Self{};
- int64_t length = array->size_;
+ int64_t length = static_cast<int64_t>(seq->size());
result.reserve(length);
for (int64_t i = 0; i < length; ++i) {
- result.emplace_back(ConstructFromAny<T>(begin[i]));
+ result.emplace_back(ConstructFromAny<T>(seq->at(i)));
}
return result;
} catch (const details::STLTypeMismatch&) {
@@ -466,7 +468,7 @@ struct TypeTraits<std::tuple<Args...>> : public
TypeTraits<details::ListTemplate
TVM_FFI_INLINE static bool CheckAnyFast(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
- return n.size_ == Nm;
+ return n.TVMFFISeqCell::size == Nm;
}
template <std::size_t... Is>
@@ -488,8 +490,8 @@ struct TypeTraits<std::tuple<Args...>> : public
TypeTraits<details::ListTemplate
TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
if (src->type_index != TypeIndex::kTVMFFIArray) return false;
const ArrayObj& n = *reinterpret_cast<const ArrayObj*>(src->v_obj);
- // check static length first
- if (n.size_ != Nm) return false;
+ // check static size first
+ if (n.TVMFFISeqCell::size != Nm) return false;
// then check element type
return CheckSubTypeAux(std::make_index_sequence<Nm>{}, n);
}
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index b15363e..c97ab71 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -110,6 +110,8 @@ struct StaticTypeKey {
static constexpr const char* kTVMFFIFunction = "ffi.Function";
/*! \brief The type key for Array */
static constexpr const char* kTVMFFIArray = "ffi.Array";
+ /*! \brief The type key for List */
+ static constexpr const char* kTVMFFIList = "ffi.List";
/*! \brief The type key for Map */
static constexpr const char* kTVMFFIMap = "ffi.Map";
/*! \brief The type key for Module */
diff --git a/include/tvm/ffi/tvm_ffi.h b/include/tvm/ffi/tvm_ffi.h
index 9d0c1d0..b55350d 100644
--- a/include/tvm/ffi/tvm_ffi.h
+++ b/include/tvm/ffi/tvm_ffi.h
@@ -33,6 +33,7 @@
#include <tvm/ffi/cast.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/container_details.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/container/tensor.h>
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 87e6ef2..21eca3b 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -51,7 +51,7 @@ from ._convert import convert
from .error import register_error
from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
-from .container import Array, Map
+from .container import Array, List, Map
from .module import Module, system_lib, load_module
from .stream import StreamContext, get_raw_stream, use_raw_stream,
use_torch_stream
from . import serialization
@@ -97,6 +97,7 @@ __all__ = [
"DLDeviceType",
"Device",
"Function",
+ "List",
"Map",
"Module",
"Object",
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 9678fd0..a76188d 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -47,6 +47,19 @@ if TYPE_CHECKING:
def GetFirstStructuralMismatch(_0: Any, _1: Any, _2: bool, _3: bool, /) ->
tuple[AccessPath, AccessPath] | None: ...
def GetGlobalFuncMetadata(_0: str, /) -> str: ...
def GetRegisteredTypeKeys() -> Sequence[str]: ...
+ def List(*args: Any) -> Any: ...
+ def ListAppend(_0: Sequence[Any], _1: Any, /) -> None: ...
+ def ListClear(_0: Sequence[Any], /) -> None: ...
+ def ListContains(_0: Sequence[Any], _1: Any, /) -> bool: ...
+ def ListErase(_0: Sequence[Any], _1: int, /) -> None: ...
+ def ListEraseRange(_0: Sequence[Any], _1: int, _2: int, /) -> None: ...
+ def ListGetItem(_0: Sequence[Any], _1: int, /) -> Any: ...
+ def ListInsert(_0: Sequence[Any], _1: int, _2: Any, /) -> None: ...
+ def ListPop(_0: Sequence[Any], _1: int, /) -> Any: ...
+ def ListReplaceSlice(_0: Sequence[Any], _1: int, _2: int, _3:
Sequence[Any], /) -> None: ...
+ def ListReverse(_0: Sequence[Any], /) -> None: ...
+ def ListSetItem(_0: Sequence[Any], _1: int, _2: Any, /) -> None: ...
+ def ListSize(_0: Sequence[Any], /) -> int: ...
def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
def Map(*args: Any) -> Any: ...
def MapCount(_0: Mapping[Any, Any], _1: Any, /) -> int: ...
@@ -93,6 +106,19 @@ __all__ = [
"GetFirstStructuralMismatch",
"GetGlobalFuncMetadata",
"GetRegisteredTypeKeys",
+ "List",
+ "ListAppend",
+ "ListClear",
+ "ListContains",
+ "ListErase",
+ "ListEraseRange",
+ "ListGetItem",
+ "ListInsert",
+ "ListPop",
+ "ListReplaceSlice",
+ "ListReverse",
+ "ListSetItem",
+ "ListSize",
"MakeObjectFromPackedArgs",
"Map",
"MapCount",
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 25cdfba..1b55fc8 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -42,6 +42,7 @@ if sys.version_info >= (3, 9):
Iterable,
Iterator,
Mapping,
+ MutableSequence,
Sequence,
)
from collections.abc import (
@@ -60,6 +61,7 @@ else: # Python 3.8
Iterable,
Iterator,
Mapping,
+ MutableSequence,
Sequence,
)
from typing import (
@@ -69,7 +71,7 @@ else: # Python 3.8
ValuesView as ValuesViewBase,
)
-__all__ = ["Array", "Map"]
+__all__ = ["Array", "List", "Map"]
T = TypeVar("T")
@@ -112,16 +114,21 @@ def getitem_helper(
start, stop, step = idx.indices(length)
return [elem_getter(obj, i) for i in range(start, stop, step)]
+ index = normalize_index(length, idx)
+ return elem_getter(obj, index)
+
+
+def normalize_index(length: int, idx: SupportsIndex) -> int:
+ """Normalize and bounds-check a Python index.
+
+ Converts ``idx`` with ``operator.index`` and maps a negative index onto
+ ``[0, length)`` by adding ``length``, mirroring built-in ``list`` indexing.
+
+ Raises ``TypeError`` if ``idx`` is not index-like, and ``IndexError`` if it
+ falls outside ``[-length, length)``. Returns the non-negative index.
+ """
try:
index = operator.index(idx)
except TypeError as exc: # pragma: no cover - defensive, matches list
behaviour
raise TypeError(f"indices must be integers or slices, not
{type(idx).__name__}") from exc
-
if index < -length or index >= length:
raise IndexError(f"Index out of range. size: {length}, got index
{index}")
if index < 0:
index += length
- return elem_getter(obj, index)
+ return index
@register_object("ffi.Array")
@@ -209,6 +216,152 @@ class Array(core.Object, Sequence[T]):
return type(self)(itertools.chain(other, self))
+@register_object("ffi.List")
+class List(core.Object, MutableSequence[T]):
+ """Mutable list container that represents a mutable sequence in the FFI."""
+
+ # tvm-ffi-stubgen(begin): object/ffi.List
+ # fmt: off
+ # fmt: on
+ # tvm-ffi-stubgen(end)
+
+ def __init__(self, input_list: Iterable[T] = ()) -> None:
+ """Construct a List from a Python sequence."""
+ self.__init_handle_by_constructor__(_ffi_api.List, *input_list)
+
+ @overload
+ def __getitem__(self, idx: SupportsIndex, /) -> T: ...
+
+ @overload
+ def __getitem__(self, idx: slice, /) -> list[T]: ...
+
+ def __getitem__(self, idx: SupportsIndex | slice, /) -> T | list[T]: #
ty: ignore[invalid-method-override]
+ """Return one element or a list for a slice."""
+ length = len(self)
+ return getitem_helper(self, _ffi_api.ListGetItem, length, idx)
+
+ @overload
+ def __setitem__(self, index: SupportsIndex, value: T) -> None: ...
+
+ @overload
+ def __setitem__(self, index: slice[int | None], value: Iterable[T]) ->
None: ...
+
+ def __setitem__(self, index: SupportsIndex | slice[int | None], value: T |
Iterable[T]) -> None:
+ """Set one element or assign a slice."""
+ if isinstance(index, slice):
+ replacement = list(cast(Iterable[T], value))
+ length = len(self)
+ start, stop, step = index.indices(length)
+ if step != 1:
+ target_indices = list(range(start, stop, step))
+ if len(replacement) != len(target_indices):
+ raise ValueError(
+ "attempt to assign sequence of size "
+ f"{len(replacement)} to extended slice of size
{len(target_indices)}"
+ )
+ for i, item in zip(target_indices, replacement):
+ _ffi_api.ListSetItem(self, i, item)
+ return
+ stop = max(stop, start)
+ _ffi_api.ListReplaceSlice(self, start, stop,
type(self)(replacement))
+ return
+
+ normalized_index = normalize_index(len(self), index)
+ _ffi_api.ListSetItem(self, normalized_index, cast(T, value))
+
+ @overload
+ def __delitem__(self, index: SupportsIndex) -> None: ...
+
+ @overload
+ def __delitem__(self, index: slice[int | None]) -> None: ...
+
+ def __delitem__(self, index: SupportsIndex | slice[int | None]) -> None:
+ """Delete one element or a slice."""
+ if isinstance(index, slice):
+ length = len(self)
+ start, stop, step = index.indices(length)
+ if step == 1:
+ stop = max(stop, start)
+ _ffi_api.ListEraseRange(self, start, stop)
+ else:
+ # Delete indices from high to low so that earlier deletions
+ # do not shift the positions of later ones.
+ indices = (
+ reversed(range(start, stop, step)) if step > 0 else
range(start, stop, step)
+ )
+ for i in indices:
+ _ffi_api.ListErase(self, i)
+ return
+ normalized_index = normalize_index(len(self), index)
+ _ffi_api.ListErase(self, normalized_index)
+
+ def insert(self, index: int, value: T) -> None:
+ """Insert value before index."""
+ length = len(self)
+ if index < 0:
+ index = max(0, index + length)
+ else:
+ index = min(index, length)
+ _ffi_api.ListInsert(self, index, value)
+
+ def append(self, value: T) -> None:
+ """Append one value to the tail."""
+ _ffi_api.ListAppend(self, value)
+
+ def clear(self) -> None:
+ """Remove all elements from the list."""
+ _ffi_api.ListClear(self)
+
+ def reverse(self) -> None:
+ """Reverse the list in-place."""
+ _ffi_api.ListReverse(self)
+
+ def pop(self, index: int = -1) -> T:
+ """Remove and return item at index (default last)."""
+ length = len(self)
+ if length == 0:
+ raise IndexError("pop from empty list")
+ normalized_index = normalize_index(length, index)
+ return cast(T, _ffi_api.ListPop(self, normalized_index))
+
+ def extend(self, values: Iterable[T]) -> None:
+ """Append elements from an iterable."""
+ end = len(self)
+ self[end:end] = values
+
+ def __len__(self) -> int:
+ """Return the number of elements in the list."""
+ return _ffi_api.ListSize(self)
+
+ def __iter__(self) -> Iterator[T]:
+ """Iterate over the elements in the list."""
+ length = len(self)
+ for i in range(length):
+ yield cast(T, _ffi_api.ListGetItem(self, i))
+
+ def __repr__(self) -> str:
+ """Return a string representation of the list."""
+ if self.__chandle__() == 0:
+ return type(self).__name__ + "(chandle=None)"
+ return "[" + ", ".join([x.__repr__() for x in self]) + "]"
+
+ def __contains__(self, value: object) -> bool:
+ """Check if the list contains a value."""
+ return _ffi_api.ListContains(self, value)
+
+ def __bool__(self) -> bool:
+ """Return True if the list is non-empty."""
+ return len(self) > 0
+
+ def __add__(self, other: Iterable[T]) -> List[T]:
+ """Concatenate two lists."""
+ return type(self)(itertools.chain(self, other))
+
+ def __radd__(self, other: Iterable[T]) -> List[T]:
+ """Concatenate two lists."""
+ return type(self)(itertools.chain(other, self))
+
+
class KeysView(KeysViewBase[K]):
"""Helper class to return keys view."""
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 9e7ff87..a07e850 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -151,6 +151,7 @@ cdef extern from "tvm/ffi/c_api.h":
kTVMFFIMap = 72
kTVMFFIModule = 73
kTVMFFIOpaquePyObject = 74
+ kTVMFFIList = 75
ctypedef void* TVMFFIObjectHandle
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 870d052..d0e98a4 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -73,6 +73,7 @@ _TYPE_SCHEMA_ORIGIN_CONVERTER = {
"Tuple": "tuple",
"ffi.Function": "Callable",
"ffi.Array": "list",
+ "ffi.List": "list",
"ffi.Map": "dict",
"ffi.OpaquePyObject": "Any",
"ffi.Object": "Object",
diff --git a/python/tvm_ffi/testing/__init__.py
b/python/tvm_ffi/testing/__init__.py
index af22210..32520cc 100644
--- a/python/tvm_ffi/testing/__init__.py
+++ b/python/tvm_ffi/testing/__init__.py
@@ -16,6 +16,7 @@
# under the License.
"""Testing utilities."""
+from ._ffi_api import * # noqa: F403
from .testing import (
TestIntPair,
TestObjectBase,
diff --git a/python/tvm_ffi/testing/_ffi_api.py
b/python/tvm_ffi/testing/_ffi_api.py
index 29453ba..8bd11c6 100644
--- a/python/tvm_ffi/testing/_ffi_api.py
+++ b/python/tvm_ffi/testing/_ffi_api.py
@@ -61,6 +61,9 @@ if TYPE_CHECKING:
def schema_id_func(_0: Callable[..., Any], /) -> Callable[..., Any]: ...
def schema_id_func_typed(_0: Callable[[int, float, Callable[..., Any]],
None], /) -> Callable[[int, float, Callable[..., Any]], None]: ...
def schema_id_int(_0: int, /) -> int: ...
+ def schema_id_list_int(_0: Sequence[int], /) -> Sequence[int]: ...
+ def schema_id_list_obj(_0: Sequence[Object], /) -> Sequence[Object]: ...
+ def schema_id_list_str(_0: Sequence[str], /) -> Sequence[str]: ...
def schema_id_map(_0: Mapping[Any, Any], /) -> Mapping[Any, Any]: ...
def schema_id_map_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]:
...
def schema_id_map_str_obj(_0: Mapping[str, Object], /) -> Mapping[str,
Object]: ...
@@ -110,6 +113,9 @@ __all__ = [
"schema_id_func",
"schema_id_func_typed",
"schema_id_int",
+ "schema_id_list_int",
+ "schema_id_list_obj",
+ "schema_id_list_str",
"schema_id_map",
"schema_id_map_str_int",
"schema_id_map_str_obj",
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index dd1d004..f3171a8 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -21,6 +21,7 @@
* \file src/ffi/container.cc
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
@@ -76,6 +77,62 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return std::any_of(n->begin(), n->end(),
[&](const Any& elem) { return eq(elem, value);
});
})
+      // ---- ffi.List: mutable sequence entry points ----
+      .def_packed("ffi.List",
+                  [](ffi::PackedArgs args, Any* ret) {
+                    *ret = List<Any>(args.data(), args.data() + args.size());
+                  })
+      .def("ffi.ListGetItem", [](const ffi::ListObj* n, int64_t i) -> Any { return n->at(i); })
+      .def("ffi.ListSetItem",
+           [](ffi::List<Any> n, int64_t i, Any value) -> void { n.Set(i, std::move(value)); })
+      .def("ffi.ListSize",
+           [](const ffi::ListObj* n) -> int64_t { return static_cast<int64_t>(n->size()); })
+      .def("ffi.ListContains",
+           [](const ffi::ListObj* n, const Any& value) -> bool {
+             AnyEqual eq;
+             return std::any_of(n->begin(), n->end(),
+                                [&](const Any& elem) { return eq(elem, value); });
+           })
+      .def("ffi.ListAppend", [](ffi::List<Any> n, const Any& value) -> void { n.push_back(value); })
+      .def("ffi.ListInsert",
+           [](ffi::List<Any> n, int64_t i, const Any& value) -> void {
+             // This global function is reachable from any FFI caller, not only the
+             // Python wrapper (which clamps the index first). Validate explicitly
+             // instead of relying on List::insert to bounds-check an arbitrary
+             // iterator built from an untrusted index.
+             const int64_t size = static_cast<int64_t>(n.size());
+             if (i < 0 || i > size) {
+               TVM_FFI_THROW(IndexError) << "ListInsert index " << i
+                                         << " out of range for list of size " << size;
+             }
+             n.insert(n.begin() + i, value);
+           })
+      .def("ffi.ListPop",
+           [](const ffi::List<Any>& n, int64_t i) -> Any {
+             ffi::ListObj* obj = n.GetListObj();
+             Any value = obj->at(i);
+             obj->erase(i);
+             return value;
+           })
+      .def("ffi.ListErase",
+           [](const ffi::List<Any>& n, int64_t i) -> void { n.GetListObj()->erase(i); })
+      .def("ffi.ListEraseRange",
+           [](const ffi::List<Any>& n, int64_t start, int64_t stop) -> void {
+             n.GetListObj()->erase(start, stop);
+           })
+      .def("ffi.ListReplaceSlice",
+           [](ffi::List<Any> n, int64_t start, int64_t stop,
+              const ffi::List<Any>& replacement) -> void {
+             // Snapshot replacement before erasing in case n and replacement alias the same object.
+             ffi::List<Any> rep_copy = n.same_as(replacement)
+                                           ? ffi::List<Any>(replacement.begin(), replacement.end())
+                                           : replacement;
+             n.GetListObj()->erase(start, stop);
+             if (rep_copy.empty()) {
+               return;
+             }
+             const ffi::ListObj* replacement_obj = rep_copy.GetListObj();
+             TVM_FFI_ICHECK(replacement_obj != nullptr);
+             n.insert(n.begin() + start, replacement_obj->begin(), replacement_obj->end());
+           })
+      .def("ffi.ListReverse",
+           [](const ffi::List<Any>& n) -> void {
+             ffi::ListObj* obj = n.GetListObj();
+             if (obj != nullptr) {
+               obj->SeqBaseObj::Reverse();
+             }
+           })
+      .def("ffi.ListClear", [](ffi::List<Any> n) -> void { n.clear(); })
.def_packed("ffi.Map",
[](ffi::PackedArgs args, Any* ret) {
TVM_FFI_ICHECK_EQ(args.size() % 2, 0);
diff --git a/src/ffi/extra/json_writer.cc b/src/ffi/extra/json_writer.cc
index 986ed49..8abb17b 100644
--- a/src/ffi/extra/json_writer.cc
+++ b/src/ffi/extra/json_writer.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/json.h>
@@ -34,6 +35,7 @@
#include <cstdint>
#include <limits>
#include <string>
+#include <unordered_set>
#include <utility>
namespace tvm {
@@ -125,6 +127,10 @@ class JSONWriter {
WriteArray(details::AnyUnsafe::CopyFromAnyViewAfterCheck<json::Array>(value));
break;
}
+ case TypeIndex::kTVMFFIList: {
+
WriteList(details::AnyUnsafe::CopyFromAnyViewAfterCheck<ffi::List<Any>>(value));
+ break;
+ }
case TypeIndex::kTVMFFIMap: {
WriteObject(details::AnyUnsafe::CopyFromAnyViewAfterCheck<json::Object>(value));
break;
@@ -184,7 +190,20 @@ class JSONWriter {
std::copy(escaped.data(), escaped.data() + escaped.size(), out_iter_);
}
- void WriteArray(const json::Array& value) {
+ void WriteArray(const json::Array& value) { WriteSequence(value); }
+
+  /*!
+   * \brief Write a List as a JSON array.
+   *
+   * A List can reference itself, so the lists on the current recursion path are
+   * tracked in active_lists_. Revisiting one of them throws ValueError instead of
+   * recursing forever.
+   */
+  void WriteList(const ffi::List<Any>& value) {
+    const void* ptr = static_cast<const void*>(value.get());
+    // insert().second is false when ptr is already on the recursion path.
+    if (!active_lists_.insert(ptr).second) {
+      TVM_FFI_THROW(ValueError) << "Cycle detected: List contains itself";
+    }
+    // Pop this list from the recursion path even if writing an element throws,
+    // so active_lists_ never retains stale entries.
+    struct ActiveListGuard {
+      std::unordered_set<const void*>* active;
+      const void* key;
+      ~ActiveListGuard() { active->erase(key); }
+    } guard{&active_lists_, ptr};
+    WriteSequence(value);
+  }
+
+ template <typename SeqType>
+ void WriteSequence(const SeqType& value) {
*out_iter_++ = '[';
if (indent_ != 0) {
total_indent_ += indent_;
@@ -248,6 +267,7 @@ class JSONWriter {
int total_indent_ = 0;
std::string result_;
std::back_insert_iterator<std::string> out_iter_;
+ std::unordered_set<const void*> active_lists_;
};
String Stringify(const json::Value& value, Optional<int> indent) {
diff --git a/src/ffi/extra/serialization.cc b/src/ffi/extra/serialization.cc
index c456204..1e5e98c 100644
--- a/src/ffi/extra/serialization.cc
+++ b/src/ffi/extra/serialization.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/dtype.h>
@@ -33,6 +34,9 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
+#include <unordered_map>
+#include <unordered_set>
+
namespace tvm {
namespace ffi {
@@ -111,7 +115,19 @@ class ObjectGraphSerializer {
case TypeIndex::kTVMFFIArray: {
Array<Any> array =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<Array<Any>>(value);
node.Set("type", ffi::StaticTypeKey::kTVMFFIArray);
- node.Set("data", CreateArrayData(array));
+ node.Set("data", CreateSequenceData(array));
+ break;
+ }
+      case TypeIndex::kTVMFFIList: {
+        List<Any> list = details::AnyUnsafe::CopyFromAnyViewAfterCheck<List<Any>>(value);
+        node.Set("type", ffi::StaticTypeKey::kTVMFFIList);
+        // Lists are mutable and may contain themselves; meeting a List that is
+        // already on the current recursion path means the graph has a cycle.
+        const void* list_ptr = static_cast<const void*>(list.get());
+        if (!active_lists_.insert(list_ptr).second) {
+          TVM_FFI_THROW(ValueError)
+              << "Cycle detected during serialization: a List contains itself";
+        }
+        // Pop this List from the recursion path even if serializing an element throws.
+        struct ActiveListGuard {
+          std::unordered_set<const void*>* active;
+          const void* key;
+          ~ActiveListGuard() { active->erase(key); }
+        } guard{&active_lists_, list_ptr};
+        node.Set("data", CreateSequenceData(list));
+        break;
}
case TypeIndex::kTVMFFIMap: {
@@ -139,11 +155,12 @@ class ObjectGraphSerializer {
}
int64_t node_index = static_cast<int64_t>(nodes_.size());
nodes_.push_back(node);
- node_index_map_.Set(value, node_index);
+ node_index_map_.emplace(value, node_index);
return node_index;
}
- json::Array CreateArrayData(const Array<Any>& value) {
+ template <typename SeqType>
+ json::Array CreateSequenceData(const SeqType& value) {
json::Array data;
data.reserve(static_cast<int64_t>(value.size()));
for (const Any& item : value) {
@@ -220,9 +237,11 @@ class ObjectGraphSerializer {
}
// maps the original value to the index of the node in the nodes_ array
- Map<Any, int64_t> node_index_map_;
+ std::unordered_map<Any, int64_t, AnyHash, AnyEqual> node_index_map_;
// records nodes that are serialized
json::Array nodes_;
+ // tracks List nodes currently being serialized (for cycle detection)
+ std::unordered_set<const void*> active_lists_;
};
json::Value ToJSONGraph(const Any& value, const Any& metadata) {
@@ -246,7 +265,7 @@ class ObjectGraphDeserializer {
return decoded_nodes_[node_index];
}
// now decode the node
- Any value = DecodeNode(nodes_[node_index].cast<json::Object>());
+ Any value = DecodeNode(node_index,
nodes_[node_index].cast<json::Object>());
decoded_nodes_[node_index] = value;
if (value == nullptr) {
decoded_null_index_ = node_index;
@@ -255,7 +274,7 @@ class ObjectGraphDeserializer {
}
private:
- Any DecodeNode(const json::Object& node) {
+ Any DecodeNode(int64_t node_index, const json::Object& node) {
String type_key = node["type"].cast<String>();
TVMFFIByteArray type_key_arr{type_key.data(), type_key.length()};
int32_t type_index;
@@ -291,7 +310,10 @@ class ObjectGraphDeserializer {
return DecodeMapData(node["data"].cast<json::Array>());
}
case TypeIndex::kTVMFFIArray: {
- return DecodeArrayData(node["data"].cast<json::Array>());
+ return
DecodeSequenceData<Array<Any>>(node["data"].cast<json::Array>());
+ }
+ case TypeIndex::kTVMFFIList: {
+ return DecodeSequenceData<List<Any>>(node["data"].cast<json::Array>());
}
case TypeIndex::kTVMFFIShape: {
Array<int64_t> data = node["data"].cast<Array<int64_t>>();
@@ -303,13 +325,14 @@ class ObjectGraphDeserializer {
}
}
- Array<Any> DecodeArrayData(const json::Array& data) {
- Array<Any> array;
- array.reserve(static_cast<int64_t>(data.size()));
+ template <typename SeqType>
+ SeqType DecodeSequenceData(const json::Array& data) {
+ SeqType sequence;
+ sequence.reserve(static_cast<int64_t>(data.size()));
for (const auto& elem : data) {
- array.push_back(GetOrDecodeNode(elem.cast<int64_t>()));
+ sequence.push_back(GetOrDecodeNode(elem.cast<int64_t>()));
}
- return array;
+ return sequence;
}
Map<Any, Any> DecodeMapData(const json::Array& data) {
diff --git a/src/ffi/extra/structural_equal.cc
b/src/ffi/extra/structural_equal.cc
index b236828..2fa05f2 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -22,6 +22,7 @@
* \brief Structural equal implementation.
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/container/tensor.h>
@@ -31,6 +32,7 @@
#include <cmath>
#include <unordered_map>
+#include <utility>
namespace tvm {
namespace ffi {
@@ -103,6 +105,10 @@ class StructEqualHandler {
return
CompareArray(AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(lhs)),
AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(rhs)));
}
+ case TypeIndex::kTVMFFIList: {
+ return
CompareList(AnyUnsafe::MoveFromAnyAfterCheck<List<Any>>(std::move(lhs)),
+
AnyUnsafe::MoveFromAnyAfterCheck<List<Any>>(std::move(rhs)));
+ }
case TypeIndex::kTVMFFIMap: {
return CompareMap(AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(lhs)),
AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(rhs)));
@@ -302,6 +308,17 @@ class StructEqualHandler {
// NOLINTNEXTLINE(performance-unnecessary-value-param)
bool CompareArray(ffi::Array<Any> lhs, ffi::Array<Any> rhs) {
+ return CompareSequence(std::move(lhs), std::move(rhs));
+ }
+
+ // NOLINTNEXTLINE(performance-unnecessary-value-param)
+ bool CompareList(ffi::List<Any> lhs, ffi::List<Any> rhs) {
+ return CompareSequence(std::move(lhs), std::move(rhs));
+ }
+
+ template <typename SeqType>
+ // NOLINTNEXTLINE(performance-unnecessary-value-param)
+ bool CompareSequence(SeqType lhs, SeqType rhs) {
if (lhs.size() != rhs.size()) {
// fast path, size mismatch, and there is no path tracing
// return false since we don't need informative error message
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index 5bb9eb1..aed6abe 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -22,6 +22,7 @@
* \brief Structural equal implementation.
*/
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/container/tensor.h>
@@ -78,6 +79,9 @@ class StructuralHashHandler {
case TypeIndex::kTVMFFIArray: {
return
HashArray(AnyUnsafe::MoveFromAnyAfterCheck<Array<Any>>(std::move(src)));
}
+ case TypeIndex::kTVMFFIList: {
+ return
HashList(AnyUnsafe::MoveFromAnyAfterCheck<List<Any>>(std::move(src)));
+ }
case TypeIndex::kTVMFFIMap: {
return HashMap(AnyUnsafe::MoveFromAnyAfterCheck<Map<Any,
Any>>(std::move(src)));
}
@@ -185,9 +189,16 @@ class StructuralHashHandler {
}
// NOLINTNEXTLINE(performance-unnecessary-value-param)
- uint64_t HashArray(Array<Any> arr) {
- uint64_t hash_value = details::StableHashCombine(arr->GetTypeKeyHash(),
arr.size());
- for (const auto& elem : arr) {
+ uint64_t HashArray(Array<Any> arr) { return HashSequence(std::move(arr)); }
+
+ // NOLINTNEXTLINE(performance-unnecessary-value-param)
+ uint64_t HashList(List<Any> list) { return HashSequence(std::move(list)); }
+
+ template <typename SeqType>
+ // NOLINTNEXTLINE(performance-unnecessary-value-param)
+ uint64_t HashSequence(SeqType seq) {
+ uint64_t hash_value = details::StableHashCombine(seq->GetTypeKeyHash(),
seq.size());
+ for (const auto& elem : seq) {
hash_value = details::StableHashCombine(hash_value, HashAny(elem));
}
return hash_value;
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index e8a232d..8c5c137 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -363,6 +363,7 @@ class TypeTable {
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIShape,
TypeIndex::kTVMFFIShape);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFITensor,
TypeIndex::kTVMFFITensor);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIArray,
TypeIndex::kTVMFFIArray);
+ ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIList,
TypeIndex::kTVMFFIList);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIMap,
TypeIndex::kTVMFFIMap);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIModule,
TypeIndex::kTVMFFIModule);
ReserveDepthOneObjectTypeIndex(StaticTypeKey::kTVMFFIOpaquePyObject,
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 0df7f1e..f6d1ff5 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -22,6 +22,7 @@
#include <dlpack/dlpack.h>
#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/container/variant.h>
@@ -354,6 +355,11 @@ Variant<int64_t, String, Array<int64_t>>
schema_variant_mix(
return v;
}
+// List types
+// Identity functions: each returns its argument unchanged. The Python tests call
+// them with Array, List, and plain Python list inputs and check that the result
+// comes back as the declared List<T> (or that an incompatible element type is
+// rejected with TypeError).
+List<int64_t> schema_id_list_int(List<int64_t> lst) { return lst; }
+List<String> schema_id_list_str(List<String> lst) { return lst; }
+List<ObjectRef> schema_id_list_obj(List<ObjectRef> lst) { return lst; }
+
// Complex nested types
Map<String, Array<int64_t>> schema_arr_map_opt(const Array<Optional<int64_t>>&
arr,
Map<String, Array<int64_t>> mp,
@@ -516,6 +522,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("testing.schema_id_arr_str", schema_test_impl::schema_id_arr_str)
.def("testing.schema_id_arr_obj", schema_test_impl::schema_id_arr_obj)
.def("testing.schema_id_arr", schema_test_impl::schema_id_arr)
+ .def("testing.schema_id_list_int", schema_test_impl::schema_id_list_int)
+ .def("testing.schema_id_list_str", schema_test_impl::schema_id_list_str)
+ .def("testing.schema_id_list_obj", schema_test_impl::schema_id_list_obj)
.def("testing.schema_id_map_str_int",
schema_test_impl::schema_id_map_str_int)
.def("testing.schema_id_map_str_str",
schema_test_impl::schema_id_map_str_str)
.def("testing.schema_id_map_str_obj",
schema_test_impl::schema_id_map_str_obj)
diff --git a/tests/cpp/extra/test_serialization.cc
b/tests/cpp/extra/test_serialization.cc
index 9d18e6a..1d9d483 100644
--- a/tests/cpp/extra/test_serialization.cc
+++ b/tests/cpp/extra/test_serialization.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/container/shape.h>
#include <tvm/ffi/dtype.h>
@@ -354,6 +355,46 @@ TEST(Serialization, AttachMetadata) {
EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected), value));
}
+TEST(Serialization, ListBasic) {
+ // Test empty list
+ List<Any> empty_list;
+ json::Object expected_empty = json::Object{
+ {"root_index", 0},
+ {"nodes", json::Array{json::Object{{"type", "ffi.List"}, {"data",
json::Array{}}}}}};
+ EXPECT_TRUE(StructuralEqual()(ToJSONGraph(empty_list), expected_empty));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_empty), empty_list));
+
+ // Test single element list
+ List<Any> single_list;
+ single_list.push_back(Any(42));
+ json::Object expected_single =
+ json::Object{{"root_index", 1},
+ {"nodes", json::Array{
+ json::Object{{"type", "int"}, {"data",
static_cast<int64_t>(42)}},
+ json::Object{{"type", "ffi.List"}, {"data",
json::Array{0}}},
+ }}};
+ EXPECT_TRUE(StructuralEqual()(ToJSONGraph(single_list), expected_single));
+ EXPECT_TRUE(StructuralEqual()(FromJSONGraph(expected_single), single_list));
+}
+
+TEST(Serialization, ListRoundTrip) {
+ // Test roundtrip for nested list
+ List<Any> nested;
+ nested.push_back(1);
+ nested.push_back(String("hello"));
+ nested.push_back(true);
+ json::Value serialized = ToJSONGraph(nested);
+ Any deserialized = FromJSONGraph(serialized);
+ EXPECT_TRUE(StructuralEqual()(deserialized, nested));
+}
+
+TEST(Serialization, DISABLED_ListCycleDetection) {
+ List<Any> lst;
+ lst.push_back(42);
+ lst.push_back(lst); // creates a cycle via shared mutable reference
+ EXPECT_ANY_THROW(ToJSONGraph(lst));
+}
+
TEST(Serialization, ShuffleNodeOrder) {
// the FromJSONGraph is agnostic to the node order
// so we can shuffle the node order as it reads nodes lazily
diff --git a/tests/cpp/extra/test_structural_equal_hash.cc
b/tests/cpp/extra/test_structural_equal_hash.cc
index a05c50c..a768319 100644
--- a/tests/cpp/extra/test_structural_equal_hash.cc
+++ b/tests/cpp/extra/test_structural_equal_hash.cc
@@ -19,6 +19,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/structural_equal.h>
#include <tvm/ffi/extra/structural_hash.h>
@@ -175,4 +176,44 @@ TEST(StructuralEqualHash, CustomTreeNode) {
EXPECT_TRUE(StructuralEqual()(diff_fa_fc, expected_diff_fa_fc));
}
+TEST(StructuralEqualHash, List) {
+ List<int> a = {1, 2, 3};
+ List<int> b = {1, 2, 3};
+ EXPECT_TRUE(StructuralEqual()(a, b));
+ EXPECT_EQ(StructuralHash()(a), StructuralHash()(b));
+
+ List<int> c = {1, 3};
+ EXPECT_FALSE(StructuralEqual()(a, c));
+ EXPECT_NE(StructuralHash()(a), StructuralHash()(c));
+}
+
+TEST(StructuralEqualHash, ListVsArrayDifferentType) {
+ Array<int> arr = {1, 2, 3};
+ List<int> lst = {1, 2, 3};
+ // Different type_index => not equal
+ EXPECT_FALSE(StructuralEqual()(arr, lst));
+ // Different type_key_hash => different hash (very likely)
+ EXPECT_NE(StructuralHash()(arr), StructuralHash()(lst));
+}
+
+TEST(StructuralEqualHash, DISABLED_ListCycleDetection) {
+ List<Any> lst;
+ lst.push_back(42);
+ lst.push_back(lst); // creates a cycle
+ EXPECT_ANY_THROW(StructuralHash()(lst));
+ EXPECT_ANY_THROW(StructuralEqual()(lst, lst));
+}
+
+TEST(StructuralEqualHash, ArraySelfInsertProducesSnapshot) {
+ Array<Any> arr;
+ arr.push_back(arr);
+
+ Array<Any> snapshot = arr[0].cast<Array<Any>>();
+ EXPECT_TRUE(snapshot.empty());
+ EXPECT_FALSE(snapshot.same_as(arr));
+
+ EXPECT_TRUE(StructuralEqual()(arr, arr));
+ EXPECT_EQ(StructuralHash()(arr), StructuralHash()(arr));
+}
+
} // namespace
diff --git a/tests/cpp/test_list.cc b/tests/cpp/test_list.cc
new file mode 100644
index 0000000..c17dc53
--- /dev/null
+++ b/tests/cpp/test_list.cc
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/list.h>
+
+#include <limits>
+#include <vector>
+
+namespace {
+
+using namespace tvm::ffi;
+
+TEST(List, Basic) {
+ List<int> list = {11, 12};
+ EXPECT_EQ(list.size(), 2U);
+ EXPECT_EQ(list[0], 11);
+ EXPECT_EQ(list[1], 12);
+}
+
+TEST(List, SharedMutation) {
+ List<int> list = {1, 2};
+ List<int> alias = list;
+
+ EXPECT_TRUE(list.same_as(alias));
+ list.Set(1, 3);
+ EXPECT_EQ(alias[1], 3);
+
+ alias.push_back(4);
+ EXPECT_EQ(list.size(), 3U);
+ EXPECT_EQ(list[2], 4);
+}
+
+TEST(List, AssignmentOperators) {
+ List<int> a = {1, 2};
+ List<int> b;
+ b = a;
+ EXPECT_TRUE(a.same_as(b));
+
+ b.Set(0, 5);
+ EXPECT_EQ(a[0], 5);
+
+ List<int> c;
+ c = std::move(b);
+ EXPECT_TRUE(c.same_as(a));
+
+ List<Any> d;
+ d = c;
+ EXPECT_EQ(d.size(), c.size());
+}
+
+TEST(List, PushPopInsertErase) {
+ List<int> list;
+ std::vector<int> vector;
+
+ for (int i = 0; i < 10; ++i) {
+ list.push_back(i);
+ vector.push_back(i);
+ }
+ EXPECT_EQ(list.size(), vector.size());
+ for (size_t i = 0; i < vector.size(); ++i) {
+ EXPECT_EQ(list[static_cast<int64_t>(i)], vector[i]);
+ }
+
+ list.insert(list.begin() + 5, 100);
+ vector.insert(vector.begin() + 5, 100);
+ EXPECT_EQ(list.size(), vector.size());
+ for (size_t i = 0; i < vector.size(); ++i) {
+ EXPECT_EQ(list[static_cast<int64_t>(i)], vector[i]);
+ }
+
+ list.erase(list.begin() + 3, list.begin() + 7);
+ vector.erase(vector.begin() + 3, vector.begin() + 7);
+ EXPECT_EQ(list.size(), vector.size());
+ for (size_t i = 0; i < vector.size(); ++i) {
+ EXPECT_EQ(list[static_cast<int64_t>(i)], vector[i]);
+ }
+
+ list.pop_back();
+ vector.pop_back();
+ EXPECT_EQ(list.size(), vector.size());
+ for (size_t i = 0; i < vector.size(); ++i) {
+ EXPECT_EQ(list[static_cast<int64_t>(i)], vector[i]);
+ }
+}
+
+TEST(List, ReserveReallocationPreservesValues) {
+ List<int> list;
+ for (int i = 0; i < 8; ++i) {
+ list.push_back(i);
+ }
+
+ auto* before_obj = list.GetListObj();
+ size_t before_capacity = list.capacity();
+
+ int64_t reserve_target = static_cast<int64_t>(before_capacity) + 32;
+ list.reserve(reserve_target);
+
+ auto* after_obj = list.GetListObj();
+ EXPECT_EQ(before_obj, after_obj);
+ EXPECT_GE(list.capacity(), before_capacity + 32);
+ for (int i = 0; i < 8; ++i) {
+ EXPECT_EQ(list[i], i);
+ }
+
+ list.reserve(1);
+ EXPECT_GE(list.capacity(), before_capacity + 32);
+}
+
+TEST(List, AnyImplicitConversionFromArray) {
+ Array<Any> array = {1, 2.5};
+ AnyView array_view = array;
+ List<double> list = array_view.cast<List<double>>();
+
+ EXPECT_EQ(list.size(), 2U);
+ EXPECT_EQ(list[0], 1.0);
+ EXPECT_EQ(list[1], 2.5);
+ EXPECT_FALSE(list.same_as(array));
+
+ list.Set(0, 99.0);
+ EXPECT_EQ(array[0].cast<int>(), 1);
+
+ List<Any> list_any = {1, 2};
+ AnyView list_view = list_any;
+ List<Any> list_any_roundtrip = list_view.cast<List<Any>>();
+ EXPECT_TRUE(list_any_roundtrip.same_as(list_any));
+}
+
+TEST(List, AnyConvertCheck) {
+ Any any = Array<Any>{String("x"), 1};
+
+ EXPECT_THROW(
+ {
+ try {
+ [[maybe_unused]] auto value = any.cast<List<int>>();
+ } catch (const Error& error) {
+ EXPECT_EQ(error.kind(), "TypeError");
+ std::string what = error.what();
+ EXPECT_NE(what.find("Cannot convert from type `List[index 0:"),
std::string::npos);
+ EXPECT_NE(what.find("to `List<int>`"), std::string::npos);
+ throw;
+ }
+ },
+ ::tvm::ffi::Error);
+}
+
+TEST(List, AnyImplicitConversionToArray) {
+ List<int> list = {10, 20, 30};
+ AnyView list_view = list;
+ auto arr = list_view.cast<Array<int>>();
+ EXPECT_EQ(arr.size(), 3U);
+ EXPECT_EQ(arr[0], 10);
+ EXPECT_EQ(arr[1], 20);
+ EXPECT_EQ(arr[2], 30);
+ EXPECT_FALSE(arr.same_as(list));
+}
+
+TEST(List, EmptyListDestructorDoesNotCrash) {
+ {
+ List<int> empty;
+ }
+ {
+ List<int> filled = {1, 2, 3};
+ filled.clear();
+ }
+}
+
+TEST(List, SeqBaseObjPopBack) {
+ List<int> list = {10, 20, 30};
+ ListObj* obj = list.GetListObj();
+ obj->pop_back();
+ EXPECT_EQ(list.size(), 2U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 20);
+ obj->pop_back();
+ obj->pop_back();
+ EXPECT_EQ(list.size(), 0U);
+ EXPECT_THROW(obj->pop_back(), Error);
+}
+
+TEST(List, SeqBaseObjErase) {
+ List<int> list = {10, 20, 30, 40, 50};
+ ListObj* obj = list.GetListObj();
+ // Erase single element at index 2 (value 30)
+ obj->erase(2);
+ EXPECT_EQ(list.size(), 4U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 20);
+ EXPECT_EQ(list[2], 40);
+ EXPECT_EQ(list[3], 50);
+ // Out of bounds
+ EXPECT_THROW(obj->erase(4), Error);
+ EXPECT_THROW(obj->erase(-1), Error);
+}
+
+TEST(List, SeqBaseObjEraseRange) {
+ List<int> list = {10, 20, 30, 40, 50};
+ ListObj* obj = list.GetListObj();
+ // Erase range [1, 3) -> removes 20, 30
+ obj->erase(int64_t{1}, int64_t{3});
+ EXPECT_EQ(list.size(), 3U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 40);
+ EXPECT_EQ(list[2], 50);
+ // No-op erase
+ obj->erase(int64_t{1}, int64_t{1});
+ EXPECT_EQ(list.size(), 3U);
+ // Invalid ranges
+ EXPECT_THROW(obj->erase(int64_t{2}, int64_t{1}), Error);
+ EXPECT_THROW(obj->erase(int64_t{-1}, int64_t{2}), Error);
+ EXPECT_THROW(obj->erase(int64_t{0}, int64_t{4}), Error);
+}
+
+TEST(List, SeqBaseObjInsert) {
+ List<int> list;
+ list.reserve(10);
+ list.push_back(10);
+ list.push_back(30);
+ ListObj* obj = list.GetListObj();
+ // Insert 20 at index 1
+ obj->insert(1, Any(int64_t{20}));
+ EXPECT_EQ(list.size(), 3U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 20);
+ EXPECT_EQ(list[2], 30);
+ // Insert at beginning
+ obj->insert(0, Any(int64_t{5}));
+ EXPECT_EQ(list[0], 5);
+ EXPECT_EQ(list.size(), 4U);
+ // Insert at end
+ obj->insert(4, Any(int64_t{40}));
+ EXPECT_EQ(list[4], 40);
+ EXPECT_EQ(list.size(), 5U);
+ // Out of bounds
+ EXPECT_THROW(obj->insert(-1, Any(int64_t{0})), Error);
+ EXPECT_THROW(obj->insert(6, Any(int64_t{0})), Error);
+}
+
+TEST(List, SeqBaseObjResize) {
+ List<int> list = {10, 20, 30};
+ list.reserve(10);
+ ListObj* obj = list.GetListObj();
+ // Grow
+ obj->resize(5);
+ EXPECT_EQ(list.size(), 5U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 20);
+ EXPECT_EQ(list[2], 30);
+ // Shrink
+ obj->resize(2);
+ EXPECT_EQ(list.size(), 2U);
+ EXPECT_EQ(list[0], 10);
+ EXPECT_EQ(list[1], 20);
+ // No-op
+ obj->resize(2);
+ EXPECT_EQ(list.size(), 2U);
+ // Negative size
+ EXPECT_THROW(obj->resize(-1), Error);
+}
+
+} // namespace
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index ffedb75..ec289cd 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -22,6 +22,7 @@ from typing import Any
import pytest
import tvm_ffi
+from tvm_ffi import testing
if sys.version_info >= (3, 9):
# PEP 585 generics
@@ -52,6 +53,9 @@ def test_bad_constructor_init_state() -> None:
with pytest.raises(TypeError):
tvm_ffi.Array(1) # ty: ignore[invalid-argument-type]
+ with pytest.raises(TypeError):
+ tvm_ffi.List(1) # ty: ignore[invalid-argument-type]
+
with pytest.raises(AttributeError):
tvm_ffi.Map(1) # ty: ignore[invalid-argument-type]
@@ -255,3 +259,243 @@ def test_array_bool(arr: list[Any], expected: bool) ->
None:
def test_map_bool(mapping: dict[Any, Any], expected: bool) -> None:
m = tvm_ffi.Map(mapping)
assert bool(m) is expected
+
+
+def test_list_basic() -> None:
+ lst = tvm_ffi.List([1, 2, 3])
+ assert isinstance(lst, tvm_ffi.List)
+ assert len(lst) == 3
+ assert tuple(lst) == (1, 2, 3)
+ assert lst[0] == 1
+ assert lst[-1] == 3
+ assert lst[:] == [1, 2, 3]
+ assert lst[::-1] == [3, 2, 1]
+
+
+def test_list_mutation_methods() -> None:
+ lst = tvm_ffi.List([1, 2, 3])
+ lst.append(4)
+ assert tuple(lst) == (1, 2, 3, 4)
+ lst.insert(2, 9)
+ assert tuple(lst) == (1, 2, 9, 3, 4)
+ value = lst.pop()
+ assert value == 4
+ assert tuple(lst) == (1, 2, 9, 3)
+ value = lst.pop(1)
+ assert value == 2
+ assert tuple(lst) == (1, 9, 3)
+ lst.extend([5, 6])
+ assert tuple(lst) == (1, 9, 3, 5, 6)
+ lst.clear()
+ assert tuple(lst) == ()
+ assert len(lst) == 0
+
+
+def test_list_setitem_and_delitem() -> None:
+ lst = tvm_ffi.List([0, 1, 2, 3, 4, 5])
+ lst[1] = 10
+ lst[-1] = 60
+ assert tuple(lst) == (0, 10, 2, 3, 4, 60)
+
+ del lst[2]
+ assert tuple(lst) == (0, 10, 3, 4, 60)
+ del lst[-1]
+ assert tuple(lst) == (0, 10, 3, 4)
+
+
+def test_list_slice_assignment_and_delete() -> None:
+ lst = tvm_ffi.List([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+ lst[2:6] = [20, 21]
+ assert tuple(lst) == (0, 1, 20, 21, 6, 7, 8, 9)
+
+ lst[3:3] = [30, 31]
+ assert tuple(lst) == (0, 1, 20, 30, 31, 21, 6, 7, 8, 9)
+
+ lst[1:9:2] = [101, 102, 103, 104]
+ assert tuple(lst) == (0, 101, 20, 102, 31, 103, 6, 104, 8, 9)
+
+ with pytest.raises(ValueError):
+ lst[0:6:2] = [1, 2]
+
+ del lst[1:8:2]
+ assert tuple(lst) == (0, 20, 31, 6, 8, 9)
+
+ del lst[2:2]
+ assert tuple(lst) == (0, 20, 31, 6, 8, 9)
+
+ del lst[1:4]
+ assert tuple(lst) == (0, 8, 9)
+
+
+def test_list_slice_edge_cases() -> None:
+ lst = tvm_ffi.List([0, 1, 2, 3, 4])
+ lst[3:1] = [9]
+ assert tuple(lst) == (0, 1, 2, 9, 3, 4)
+
+ lst = tvm_ffi.List([0, 1, 2, 3])
+ del lst[3:1]
+ assert tuple(lst) == (0, 1, 2, 3)
+
+ lst = tvm_ffi.List([0, 1, 2, 3, 4, 5])
+ del lst[5:1:-1]
+ assert tuple(lst) == (0, 1)
+
+ lst = tvm_ffi.List([0, 1, 2, 3])
+ del lst[::-1]
+ assert tuple(lst) == ()
+
+
+def test_list_contains_bool_repr_and_concat() -> None:
+ lst = tvm_ffi.List([1, 2, 3])
+ assert 2 in lst
+ assert 5 not in lst
+ assert bool(lst) is True
+ assert str(lst) == "[1, 2, 3]"
+ assert tuple(lst.__add__([4, 5])) == (1, 2, 3, 4, 5)
+ assert tuple(lst.__radd__([0])) == (0, 1, 2, 3)
+
+ empty = tvm_ffi.List()
+ assert bool(empty) is False
+ assert str(empty) == "[]"
+
+
+def test_list_insert_index_normalization() -> None:
+ lst = tvm_ffi.List([1, 2, 3])
+ lst.insert(-100, 0)
+ assert tuple(lst) == (0, 1, 2, 3)
+ lst.insert(100, 4)
+ assert tuple(lst) == (0, 1, 2, 3, 4)
+
+
+def test_list_error_cases() -> None:
+ lst = tvm_ffi.List([1, 2, 3])
+
+ with pytest.raises(IndexError):
+ _ = lst[3]
+ with pytest.raises(IndexError):
+ _ = lst[-4]
+ with pytest.raises(IndexError):
+ lst[3] = 0
+ with pytest.raises(IndexError):
+ del lst[3]
+
+ with pytest.raises(IndexError):
+ tvm_ffi.List().pop()
+
+
+def test_list_pickle_roundtrip() -> None:
+ lst = tvm_ffi.List([1, "a", {"k": 2}])
+ restored = pickle.loads(pickle.dumps(lst))
+ assert isinstance(restored, tvm_ffi.List)
+ assert restored[0] == 1
+ assert restored[1] == "a"
+ assert isinstance(restored[2], tvm_ffi.Map)
+ assert restored[2]["k"] == 2
+
+
+def test_list_reverse() -> None:
+ lst = tvm_ffi.List([3, 1, 2])
+ lst.reverse()
+ assert tuple(lst) == (2, 1, 3)
+
+ empty = tvm_ffi.List()
+ empty.reverse()
+ assert tuple(empty) == ()
+
+ single = tvm_ffi.List([42])
+ single.reverse()
+ assert tuple(single) == (42,)
+
+
+def test_list_self_aliasing_slice_assignment() -> None:
+ lst = tvm_ffi.List([0, 1, 2, 3, 4])
+ lst[1:3] = lst
+ assert tuple(lst) == (0, 0, 1, 2, 3, 4, 3, 4)
+
+
+def test_list_is_mutable_and_shared() -> None:
+ lst = tvm_ffi.List([1, 2])
+ alias = lst
+ alias.append(3)
+ alias[0] = 10
+ assert tuple(lst) == (10, 2, 3)
+
+
+# ---------------------------------------------------------------------------
+# Cross-conversion tests: Array <-> List via SeqTypeTraitsBase
+# ---------------------------------------------------------------------------
+def test_seq_cross_conv_list_to_array_int() -> None:
+ """List<int> passed to a function expecting Array<int> (new behavior)."""
+ lst = tvm_ffi.List([10, 20, 30])
+ result = testing.schema_id_arr_int(lst)
+ assert isinstance(result, tvm_ffi.Array)
+ assert list(result) == [10, 20, 30]
+
+
+def test_seq_cross_conv_array_to_list_int() -> None:
+ """Array<int> passed to a function expecting List<int>."""
+ arr = tvm_ffi.Array([10, 20, 30])
+ result = testing.schema_id_list_int(arr)
+ assert isinstance(result, tvm_ffi.List)
+ assert list(result) == [10, 20, 30]
+
+
+def test_seq_cross_conv_list_to_array_str() -> None:
+ """List<String> passed to a function expecting Array<String>."""
+ lst = tvm_ffi.List(["a", "b", "c"])
+ result = testing.schema_id_arr_str(lst)
+ assert isinstance(result, tvm_ffi.Array)
+ assert list(result) == ["a", "b", "c"]
+
+
+def test_seq_cross_conv_array_to_list_str() -> None:
+ """Array<String> passed to a function expecting List<String>."""
+ arr = tvm_ffi.Array(["a", "b", "c"])
+ result = testing.schema_id_list_str(arr)
+ assert isinstance(result, tvm_ffi.List)
+ assert list(result) == ["a", "b", "c"]
+
+
+def test_seq_cross_conv_empty_list_to_array() -> None:
+ """Empty List passed to a function expecting Array<int>."""
+ lst = tvm_ffi.List([])
+ result = testing.schema_id_arr_int(lst)
+ assert isinstance(result, tvm_ffi.Array)
+ assert len(result) == 0
+
+
+def test_seq_cross_conv_empty_array_to_list() -> None:
+ """Empty Array passed to a function expecting List<int>."""
+ arr = tvm_ffi.Array([])
+ result = testing.schema_id_list_int(arr)
+ assert isinstance(result, tvm_ffi.List)
+ assert len(result) == 0
+
+
+def test_seq_cross_conv_python_list_to_array() -> None:
+ """Plain Python list passed to Array<int> function."""
+ result = testing.schema_id_arr_int([1, 2, 3])
+ assert isinstance(result, tvm_ffi.Array)
+ assert list(result) == [1, 2, 3]
+
+
+def test_seq_cross_conv_python_list_to_list() -> None:
+ """Plain Python list passed to List<int> function."""
+ result = testing.schema_id_list_int([1, 2, 3])
+ assert isinstance(result, tvm_ffi.List)
+ assert list(result) == [1, 2, 3]
+
+
+def test_seq_cross_conv_incompatible_list_to_array() -> None:
+ """List with incompatible element types should fail when cast to
Array<int>."""
+ lst = tvm_ffi.List(["not", "ints"])
+ with pytest.raises(TypeError):
+ testing.schema_id_arr_int(lst) # type: ignore[arg-type]
+
+
+def test_seq_cross_conv_incompatible_array_to_list() -> None:
+ """Array with incompatible element types should fail when cast to
List<int>."""
+ arr = tvm_ffi.Array(["not", "ints"])
+ with pytest.raises(TypeError):
+ testing.schema_id_list_int(arr) # type: ignore[arg-type]
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 3f1554c..cee9084 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -189,6 +189,7 @@ def test_unregistered_object_fallback() -> None:
(tvm_ffi.core.Bytes, "ffi.Bytes", tvm_ffi.Object),
(tvm_ffi.Tensor, "ffi.Tensor", tvm_ffi.Object),
(tvm_ffi.Array, "ffi.Array", tvm_ffi.Object),
+ (tvm_ffi.List, "ffi.List", tvm_ffi.Object),
(tvm_ffi.Map, "ffi.Map", tvm_ffi.Object),
(tvm_ffi.access_path.AccessStep, "ffi.reflection.AccessStep",
tvm_ffi.Object),
(tvm_ffi.access_path.AccessPath, "ffi.reflection.AccessPath",
tvm_ffi.Object),