This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 0a6dc7a5f646f74fb025c0ab43fa4adf4dce2003 Author: tqchen <[email protected]> AuthorDate: Tue Apr 22 09:50:47 2025 -0400 migrate shape to use ffi::Shape --- include/tvm/runtime/container/shape_tuple.h | 165 +--------------------------- include/tvm/runtime/object.h | 2 +- src/runtime/relax_vm/attn_utils.h | 2 +- 3 files changed, 5 insertions(+), 164 deletions(-) diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h index 6532fe0fc9..6a0497049f 100644 --- a/include/tvm/runtime/container/shape_tuple.h +++ b/include/tvm/runtime/container/shape_tuple.h @@ -27,173 +27,14 @@ #include <ostream> #include <utility> #include <vector> - +#include <tvm/ffi/container/shape.h> #include "./base.h" namespace tvm { namespace runtime { -/*! \brief An object representing a shape tuple. */ -class ShapeTupleObj : public Object { - public: - /*! \brief The type of shape index element. */ - using index_type = int64_t; - /*! \brief The pointer to shape tuple data. */ - index_type* data; - /*! \brief The size of the shape tuple object. */ - uint64_t size; - - /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ - index_type Product() const; - - static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple; - static constexpr const char* _type_key = "runtime.ShapeTuple"; - TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeTupleObj, Object); - - private: - /*! \brief ShapeTuple object which is moved from std::vector container. */ - class FromStd; - - friend class ShapeTuple; -}; - -/*! \brief An object representing shape tuple moved from std::vector. */ -class ShapeTupleObj::FromStd : public ShapeTupleObj { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeTupleObj::index_type; - /*! - * \brief Construct a new FromStd object - * - * \param other The moved/copied std::vector object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - explicit FromStd(std::vector<index_type> other) : data_container{other} {} - - private: - /*! \brief Container that holds the memory. */ - std::vector<index_type> data_container; - - friend class ShapeTuple; -}; - -/*! - * \brief Reference to shape tuple objects. - */ -class ShapeTuple : public ObjectRef { - public: - /*! \brief The type of shape index element. */ - using index_type = ShapeTupleObj::index_type; - - /*! - * \brief Construct an empty shape tuple. - */ - ShapeTuple() : ShapeTuple(std::vector<index_type>()) {} - - /*! - * \brief Constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template <typename IterType> - ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector<index_type>(begin, end)) {} - - /*! - * \brief constructor from initializer list - * \param shape The initializer list - */ - ShapeTuple(std::initializer_list<index_type> shape) : ShapeTuple(shape.begin(), shape.end()) {} - - /*! - * \brief Construct a new ShapeTuple object - * - * \param shape The moved/copied std::vector object - * - * \note If user passes const reference, it will trigger copy. If it's rvalue, - * it will be moved into other. - */ - ShapeTuple(std::vector<index_type> shape); // NOLINT(*) - - /*! - * \brief Return the data pointer - * - * \return const index_type* data pointer - */ - const index_type* data() const { return get()->data; } - - /*! - * \brief Return the size of the shape tuple - * - * \return size_t shape tuple size - */ - size_t size() const { return get()->size; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - index_type operator[](size_t idx) const { - ICHECK(idx < this->size()) << "IndexError: indexing " << idx << " on an array of size " - << this->size(); - return this->data()[idx]; - } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - index_type at(size_t idx) const { return this->operator[](idx); } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the shape tuple */ - index_type front() const { return this->at(0); } - - /*! \return The last element of the shape tuple */ - index_type back() const { return this->at(this->size() - 1); } - - /*! \return begin iterator */ - const index_type* begin() const { return get()->data; } - - /*! \return end iterator */ - const index_type* end() const { return (get()->data + size()); } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj); -}; - -inline ShapeTuple::ShapeTuple(std::vector<index_type> shape) { - auto ptr = make_object<ShapeTupleObj::FromStd>(std::move(shape)); - ptr->size = ptr->data_container.size(); - ptr->data = ptr->data_container.data(); - data_ = std::move(ptr); -} - -inline ShapeTupleObj::index_type ShapeTupleObj::Product() const { - index_type numel = 1; - for (int i = 0, n = this->size; i < n; ++i) { - numel *= this->data[i]; - } - return numel; -} - -inline std::ostream& operator<<(std::ostream& os, const ShapeTuple& shape) { - os << '['; - for (size_t i = 0; i < shape->size; ++i) { - if (i != 0) { - os << ", "; - } - os << shape->data[i]; - } - os << ']'; - return os; -} - +using ShapeTuple = tvm::ffi::Shape; +using ShapeTupleObj = tvm::ffi::ShapeObj; using IntTuple = ShapeTuple; using IntTupleObj = ShapeTupleObj; diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index dd7bdacd63..4dcb179fe0 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -58,7 +58,7 @@ enum TypeIndex : int32_t { /*! \brief runtime::NDArray. */ kRuntimeNDArray = TVMFFITypeIndex::kTVMFFINDArray, /*! \brief runtime::ShapeTuple. */ - kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShapeTuple, + kRuntimeShapeTuple = TVMFFITypeIndex::kTVMFFIShape, // Extra builtin static index here kCustomStaticIndex = TVMFFITypeIndex::kTVMFFIStaticObjectEnd, /*! \brief runtime::PackedFunc. */ diff --git a/src/runtime/relax_vm/attn_utils.h b/src/runtime/relax_vm/attn_utils.h index 8138aa7bbd..3e53d6bbc2 100644 --- a/src/runtime/relax_vm/attn_utils.h +++ b/src/runtime/relax_vm/attn_utils.h @@ -730,7 +730,7 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager { if (shape.defined()) { ICHECK_EQ(shape.value().size(), 1); copy_dst.ndim = 1; - copy_dst.shape = shape.value()->data; + copy_dst.shape = const_cast<int64_t*>(shape.value()->data); } copy_dst.byte_offset = dst_elem_offset * sizeof(int32_t);
