This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit cd2e404d9cd874a3c3d9f71c82c998b745be8636 Author: tqchen <[email protected]> AuthorDate: Tue Apr 22 09:50:40 2025 -0400 [FFI] Introduce Shape --- ffi/include/tvm/ffi/c_api.h | 57 ++++-- ffi/include/tvm/ffi/container/container_details.h | 2 +- ffi/include/tvm/ffi/container/shape.h | 222 ++++++++++++++++++++++ ffi/include/tvm/ffi/dtype.h | 2 - ffi/include/tvm/ffi/object.h | 32 +-- ffi/include/tvm/ffi/optional.h | 4 + ffi/include/tvm/ffi/string.h | 4 +- ffi/src/ffi/object.cc | 2 + ffi/tests/cpp/test_shape.cc | 75 +++++++ 9 files changed, 356 insertions(+), 44 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 5627d46fa0..f0f609e4d1 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -58,40 +58,72 @@ typedef enum { // - `Any::type_index` is never `kTVMFFIRawStr` // - `AnyView::type_index` can be `kTVMFFIRawStr` // - // NOTE: kTVMFFIAny is a root type of everything - // we include it so TypeIndex captures all possible runtime values. - // `kTVMFFIAny` code will never appear in Any::type_index. - // However, it may appear in field annotations during reflection. - // + /*! + * \brief The root type of all FFI values. + * + * We include it so TypeIndex captures all possible runtime values. + * `kTVMFFIAny` code will never appear in Any::type_index. + * However, it may appear in field annotations during reflection. + */ kTVMFFIAny = -1, + /*! \brief None/nullptr value */ kTVMFFINone = 0, + /*! \brief POD int value */ kTVMFFIInt = 1, + /*! \brief POD bool value */ kTVMFFIBool = 2, + /*! \brief POD float value */ kTVMFFIFloat = 3, + /*! \brief POD opaque pointer value */ kTVMFFIOpaquePtr = 4, + /*! \brief DLDataType */ kTVMFFIDataType = 5, + /*! \brief DLDevice */ kTVMFFIDevice = 6, + /*! \brief DLTensor* */ kTVMFFIDLTensorPtr = 7, + /*! \brief const char* */ kTVMFFIRawStr = 8, + /*! \brief TVMFFIByteArray* */ kTVMFFIByteArrayPtr = 9, + /*!
\brief R-value reference to ObjectRef */ kTVMFFIObjectRValueRef = 10, - // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin) - // roughly order in terms of their ptential dependencies + /*! \brief Start of statically defined objects. */ kTVMFFIStaticObjectBegin = 64, + /*! + * \brief Object, all objects start with TVMFFIObject as their header. + * \note Derived objects append their own fields after this header. + */ kTVMFFIObject = 64, + /*! + * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ kTVMFFIStr = 65, + /*! + * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ kTVMFFIBytes = 66, + /*! \brief Error object. */ kTVMFFIError = 67, + /*! \brief Function object. */ kTVMFFIFunc = 68, + /*! \brief Array object. */ kTVMFFIArray = 69, + /*! \brief Map object. */ kTVMFFIMap = 70, - kTVMFFIShapeTuple = 71, + /*! + * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } + */ kTVMFFIShape = 71, + /*! + * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... } + */ kTVMFFINDArray = 72, + /*! \brief Runtime module object. */ kTVMFFIRuntimeModule = 73, kTVMFFIStaticObjectEnd, // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) - // kTVMFFIDynObject is used to indicate that the type index - // is dynamic and needs to be looked up at runtime + /*! \brief Start of type indices that are allocated at runtime. */ kTVMFFIDynObjectBegin = 128 #ifdef __cplusplus }; @@ -157,10 +189,9 @@ typedef struct TVMFFIAny { } TVMFFIAny; /*! - * \brief Byte array data structure used by String and Bytes. + * \brief Byte array data structure used by String and Bytes. * - * String and bytes follows the layout of C-style string, - * with a null-terminated character array and a size field. + * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } * * \note This byte array data structure layout differs in 32/64 bit platforms.
* as size_t equals to the size of the pointer, use this convetion to diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h index 73cb34afda..bcdaa18ae7 100644 --- a/ffi/include/tvm/ffi/container/container_details.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -19,7 +19,7 @@ /*! * \file tvm/ffi/container/container_details.h - * \brief Common utilities for container types. + * \brief Common utilities for typed container types. */ #ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ #define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h new file mode 100644 index 0000000000..f6f19d9289 --- /dev/null +++ b/ffi/include/tvm/ffi/container/shape.h @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/shape.h + * \brief Container to store shape of an NDArray. + */ +#ifndef TVM_FFI_CONTAINER_SHAPE_H_ +#define TVM_FFI_CONTAINER_SHAPE_H_ + +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/error.h> +#include <tvm/ffi/type_traits.h> + +#include <algorithm> +#include <iterator> +#include <ostream> +#include <utility> +#include <vector> + +namespace tvm { +namespace ffi { + +/*!
\brief An object representing a shape tuple. */ +class ShapeObj : public Object { + public: + using index_type = int64_t; + + const int64_t* data; + size_t size; + + /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ + int64_t Product() const { + int64_t product = 1; + for (size_t i = 0; i < this->size; ++i) { + product *= this->data[i]; + } + return product; + } + + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIShape; + TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeObj, Object); +}; + +namespace details { + +class ShapeObjStdImpl : public ShapeObj { + public: + explicit ShapeObjStdImpl(std::vector<int64_t> other) : data_{std::move(other)} { + this->data = data_.data(); + this->size = static_cast<size_t>(data_.size()); + } + + private: + std::vector<int64_t> data_; +}; + +TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeEmptyShape() { + ObjectPtr<ShapeObj> p = make_object<ShapeObj>(); + p->data = nullptr; + p->size = 0; + return p; +} + +// Allocate a ShapeObj whose int64_t elements are stored inline, right after the object header. +template <typename IterType> +TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end) { + size_t length = static_cast<size_t>(std::distance(begin, end)); + ObjectPtr<ShapeObj> p = make_inplace_array_object<ShapeObj, int64_t>(length); + static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); + static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); + int64_t* dest_data = + reinterpret_cast<int64_t*>(reinterpret_cast<char*>(p.get()) + sizeof(ShapeObj)); + p->data = dest_data; + p->size = length; + std::copy(begin, end, dest_data); + return p; +} + +} // namespace details + +/*! + * \brief Reference to shape object. + */ +class Shape : public ObjectRef { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeObj::index_type; + + /*! \brief Default constructor */ + Shape() : ObjectRef(details::MakeEmptyShape()) {} + + /*!
+ * \brief Constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template <typename IterType> + Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} + + /*! + * \brief Constructor from Array<int64_t> + * \param shape The Array<int64_t> + * + * \note This constructor will copy the data content. + */ + Shape(Array<int64_t> shape) // NOLINT(*) + : Shape(shape.begin(), shape.end()) {} + + /*! + * \brief constructor from initializer list + * \param shape The initializer list + */ + Shape(std::initializer_list<int64_t> shape) : Shape(shape.begin(), shape.end()) {} + + /*! + * \brief Constructor from std::vector<int64_t>; the vector's storage is moved into the object. + * + * \param other The vector holding the shape values. + */ + Shape(std::vector<int64_t> other) // NOLINT(*) + : ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {} + + /*! + * \brief Return the data pointer + * + * \return const index_type* data pointer + */ + const int64_t* data() const { return get()->data; } + + /*! + * \brief Return the size of the shape tuple + * + * \return size_t shape tuple size + */ + size_t size() const { return get()->size; } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element; throws IndexError when idx >= size(). + */ + int64_t operator[](size_t idx) const { + if (idx >= this->size()) { + TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); + } + return this->data()[idx]; + } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element; throws IndexError when idx >= size(). + */ + int64_t at(size_t idx) const { return this->operator[](idx); } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the shape tuple */ + int64_t front() const { return this->at(0); } + + /*!
\return The last element of the shape tuple; throws IndexError when empty. */ + int64_t back() const { return this->at(this->empty() ? 0 : this->size() - 1); } + + /*! \return begin iterator */ + const int64_t* begin() const { return get()->data; } + + /*! \return end iterator */ + const int64_t* end() const { return (get()->data + size()); } + + /*! \return The product of the shape tuple */ + int64_t Product() const { return get()->Product(); } + + TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj); +}; + +inline std::ostream& operator<<(std::ostream& os, const Shape& shape) { + os << '['; + for (size_t i = 0; i < shape.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << shape[i]; + } + os << ']'; + return os; +} + +// Shape +template <> +inline constexpr bool use_default_type_traits_v<Shape> = false; + +// Allow auto conversion from Array<int64_t> to Shape, but not from Shape to Array<int64_t> +template <> +struct TypeTraits<Shape> : public ObjectRefWithFallbackTraitsBase<Shape, Array<int64_t>> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; + static TVM_FFI_INLINE Shape ConvertFallbackValue(Array<int64_t> src) { return Shape(std::move(src)); } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h index 257b6bc158..f4a7d31390 100644 --- a/ffi/include/tvm/ffi/dtype.h +++ b/ffi/include/tvm/ffi/dtype.h @@ -20,8 +20,6 @@ /*! * \file tvm/ffi/dtype.h * \brief Data type handling.
- - * This file contains convenient methods for holding */ #ifndef TVM_FFI_DTYPE_H_ #define TVM_FFI_DTYPE_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index f8d3b7e0a9..e3cabaef3e 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -52,6 +52,9 @@ struct StaticTypeKey { static constexpr const char* kTVMFFIRawStr = "const char*"; static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef"; + static constexpr const char* kTVMFFIShape = "object.Shape"; + static constexpr const char* kTVMFFIBytes = "object.Bytes"; + static constexpr const char* kTVMFFIStr = "object.String"; }; /*! @@ -60,32 +63,9 @@ struct StaticTypeKey { * \return the type key */ inline std::string TypeIndexToTypeKey(int32_t type_index) { - switch (type_index) { - case TypeIndex::kTVMFFIAny: - return StaticTypeKey::kTVMFFIAny; - case TypeIndex::kTVMFFINone: - return StaticTypeKey::kTVMFFINone; - case TypeIndex::kTVMFFIBool: - return StaticTypeKey::kTVMFFIBool; - case TypeIndex::kTVMFFIInt: - return StaticTypeKey::kTVMFFIInt; - case TypeIndex::kTVMFFIFloat: - return StaticTypeKey::kTVMFFIFloat; - case TypeIndex::kTVMFFIOpaquePtr: - return StaticTypeKey::kTVMFFIOpaquePtr; - case TypeIndex::kTVMFFIDataType: - return StaticTypeKey::kTVMFFIDataType; - case TypeIndex::kTVMFFIDevice: - return StaticTypeKey::kTVMFFIDevice; - case TypeIndex::kTVMFFIRawStr: - return StaticTypeKey::kTVMFFIRawStr; - case TypeIndex::kTVMFFIObjectRValueRef: - return StaticTypeKey::kTVMFFIObjectRValueRef; - default: { - const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index); - return type_info->type_key; - } - } + // kTVMFFIAny may appear in reflection field annotations; keep its key explicit. + if (type_index == TypeIndex::kTVMFFIAny) return StaticTypeKey::kTVMFFIAny; + return TVMFFIGetTypeInfo(type_index)->type_key; } namespace details { diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h index 4ac9bc198b..7b3f69ef99 100644 --- a/ffi/include/tvm/ffi/optional.h +++ b/ffi/include/tvm/ffi/optional.h @@ -20,6 +20,7 @@ /*!
* \file tvm/ffi/optional.h * \brief Runtime Optional container types. + * \note Optional<T> is specialized for ObjectRef types T and uses nullptr to indicate nullopt. */ #ifndef TVM_FFI_OPTIONAL_H_ #define TVM_FFI_OPTIONAL_H_ @@ -34,6 +35,9 @@ namespace tvm { namespace ffi { +// Note: We place optional in tvm/ffi instead of tvm/ffi/container +// because optional itself is an inherent core component of the FFI system. + template <typename T> inline constexpr bool is_optional_type_v = false; diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h index 3ffc358d7d..70a3053511 100644 --- a/ffi/include/tvm/ffi/string.h +++ b/ffi/include/tvm/ffi/string.h @@ -63,7 +63,7 @@ class BytesObjBase : public Object { class BytesObj : public BytesObjBase { public: static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; - static constexpr const char* _type_key = "object.Bytes"; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIBytes; static const constexpr bool _type_final = true; TVM_FFI_DECLARE_STATIC_OBJECT_INFO(BytesObj, Object); }; @@ -72,7 +72,7 @@ class BytesObj : public BytesObjBase { class StringObj : public BytesObjBase { public: static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; - static constexpr const char* _type_key = "object.String"; + static constexpr const char* _type_key = StaticTypeKey::kTVMFFIStr; static const constexpr bool _type_final = true; TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object); }; diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc index eeaf857115..5703127c7d 100644 --- a/ffi/src/ffi/object.cc +++ b/ffi/src/ffi/object.cc @@ -242,6 +242,7 @@ class TypeTable { Object::_type_child_slots, Object::_type_child_slots_can_overflow, -1); // reserve the static types + ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIFloat,
TypeIndex::kTVMFFIFloat); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIBool, TypeIndex::kTVMFFIBool); @@ -252,6 +253,7 @@ class TypeTable { ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr); ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef, TypeIndex::kTVMFFIObjectRValueRef); + // no need to reserve for object types as they will be registered } void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) { diff --git a/ffi/tests/cpp/test_shape.cc b/ffi/tests/cpp/test_shape.cc new file mode 100644 index 0000000000..c6d8d5dbd8 --- /dev/null +++ b/ffi/tests/cpp/test_shape.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +#include <gtest/gtest.h> +#include <tvm/ffi/container/array.h> +#include <tvm/ffi/container/shape.h> + +namespace { + +using namespace tvm::ffi; + +TEST(Shape, Basic) { + Shape shape = Shape({1, 2, 3}); + EXPECT_EQ(shape.size(), 3U); + EXPECT_EQ(shape[0], 1); + EXPECT_EQ(shape[1], 2); + EXPECT_EQ(shape[2], 3); + EXPECT_EQ(shape.front(), 1); + EXPECT_EQ(shape.back(), 3); + + Shape shape2 = Shape(Array<int64_t>({4, 5, 6, 7})); + EXPECT_EQ(shape2.size(), 4U); + EXPECT_EQ(shape2[0], 4); + EXPECT_EQ(shape2[1], 5); + EXPECT_EQ(shape2[2], 6); + EXPECT_EQ(shape2[3], 7); + + std::vector<int64_t> vec = {8, 9, 10}; + Shape shape3 = Shape(std::move(vec)); + EXPECT_EQ(shape3.size(), 3U); + EXPECT_EQ(shape3[0], 8); + EXPECT_EQ(shape3[1], 9); + EXPECT_EQ(shape3[2], 10); + EXPECT_EQ(shape3.Product(), 8 * 9 * 10); + + Shape shape4 = Shape(); + EXPECT_EQ(shape4.size(), 0U); + EXPECT_EQ(shape4.Product(), 1); + EXPECT_ANY_THROW(shape4.back()); +} + +TEST(Shape, AnyConvert) { + Shape shape0 = Shape({1, 2, 3}); + Any any0 = shape0; + + Shape shape1 = any0; + EXPECT_EQ(shape1.size(), 3U); + EXPECT_EQ(shape1[0], 1); + EXPECT_EQ(shape1[1], 2); + EXPECT_EQ(shape1[2], 3); + + Array<Any> arr({1, 2}); + AnyView any_view0 = arr; + Shape shape2 = any_view0; + EXPECT_EQ(shape2.size(), 2U); + EXPECT_EQ(shape2[0], 1); + EXPECT_EQ(shape2[1], 2); +} + +} // namespace
