This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit cd2e404d9cd874a3c3d9f71c82c998b745be8636
Author: tqchen <[email protected]>
AuthorDate: Tue Apr 22 09:50:40 2025 -0400

    [FFI] Introduce Shape
---
 ffi/include/tvm/ffi/c_api.h                       |  57 ++++--
 ffi/include/tvm/ffi/container/container_details.h |   2 +-
 ffi/include/tvm/ffi/container/shape.h             | 219 ++++++++++++++++++++++
 ffi/include/tvm/ffi/dtype.h                       |   2 -
 ffi/include/tvm/ffi/object.h                      |  30 +--
 ffi/include/tvm/ffi/optional.h                    |   4 +
 ffi/include/tvm/ffi/string.h                      |   4 +-
 ffi/src/ffi/object.cc                             |   2 +
 ffi/tests/cpp/test_shape.cc                       |  72 +++++++
 9 files changed, 348 insertions(+), 44 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 5627d46fa0..f0f609e4d1 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -58,40 +58,72 @@ typedef enum {
   // - `Any::type_index` is never `kTVMFFIRawStr`
   // - `AnyView::type_index` can be `kTVMFFIRawStr`
   //
-  // NOTE: kTVMFFIAny is a root type of everything
-  // we include it so TypeIndex captures all possible runtime values.
-  // `kTVMFFIAny` code will never appear in Any::type_index.
-  // However, it may appear in field annotations during reflection.
-  //
+  /*!
+   * \brief The root type of all FFI values, covering both POD values and objects.
+   *
+   * We include it so TypeIndex captures all possible runtime values.
+   * `kTVMFFIAny` code will never appear in Any::type_index.
+   * However, it may appear in field annotations during reflection.
+   */
   kTVMFFIAny = -1,
+  /*! \brief None/nullptr value */
   kTVMFFINone = 0,
+  /*! \brief POD int value */
   kTVMFFIInt = 1,
+  /*! \brief POD bool value */
   kTVMFFIBool = 2,
+  /*! \brief POD float value */
   kTVMFFIFloat = 3,
+  /*! \brief POD opaque pointer value */
   kTVMFFIOpaquePtr = 4,
+  /*! \brief DLDataType */
   kTVMFFIDataType = 5,
+  /*! \brief DLDevice */
   kTVMFFIDevice = 6,
+  /*! \brief DLTensor* */
   kTVMFFIDLTensorPtr = 7,
+  /*! \brief const char* */
   kTVMFFIRawStr = 8,
+  /*! \brief TVMFFIByteArray* */
   kTVMFFIByteArrayPtr = 9,
+  /*! \brief R-value reference to ObjectRef */
   kTVMFFIObjectRValueRef = 10,
-  // [Section] Static Boxed: [kTVMFFIStaticObjectBegin, kTVMFFIDynObjectBegin)
-  // roughly order in terms of their ptential dependencies
+  /*! \brief Start of statically defined objects. */
   kTVMFFIStaticObjectBegin = 64,
+  /*!
+   * \brief Object, all objects start with TVMFFIObject as their header.
+   * \note We will also add other fields
+   */
   kTVMFFIObject = 64,
+  /*!
+   * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
+   */
   kTVMFFIStr = 65,
+  /*!
+   * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... }
+   */
   kTVMFFIBytes = 66,
+  /*! \brief Error object. */
   kTVMFFIError = 67,
+  /*! \brief Function object. */
   kTVMFFIFunc = 68,
+  /*! \brief Array object. */
   kTVMFFIArray = 69,
+  /*! \brief Map object. */
   kTVMFFIMap = 70,
-  kTVMFFIShapeTuple = 71,
+  /*!
+   * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... }
+   */
+  kTVMFFIShape = 71,
+  /*!
+   * \brief NDArray object, layout = { TVMFFIObject, DLTensor, ... }
+   */
   kTVMFFINDArray = 72,
+  /*! \brief Runtime module object. */
   kTVMFFIRuntimeModule = 73,
   kTVMFFIStaticObjectEnd,
   // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo)
-  // kTVMFFIDynObject is used to indicate that the type index
-  // is dynamic and needs to be looked up at runtime
+  /*! \brief Start of type indices that are allocated at runtime. */
   kTVMFFIDynObjectBegin = 128
 #ifdef __cplusplus
 };
@@ -157,10 +189,9 @@ typedef struct TVMFFIAny {
 } TVMFFIAny;
 
 /*!
- *  \brief Byte array data structure used by String and Bytes.
+ * \brief Byte array data structure used by String and Bytes.
  *
- *  String and bytes follows the layout of C-style string,
- *  with a null-terminated character array and a size field.
+ * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... }
  *
  * \note This byte array data structure layout differs in 32/64 bit platforms.
  *       as size_t equals to the size of the pointer, use this convetion to
diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h
index 73cb34afda..bcdaa18ae7 100644
--- a/ffi/include/tvm/ffi/container/container_details.h
+++ b/ffi/include/tvm/ffi/container/container_details.h
@@ -19,7 +19,7 @@
 
 /*!
  * \file tvm/ffi/container/container_details.h
- * \brief Common utilities for container types.
+ * \brief Common utilities for typed container types.
  */
 #ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_
 #define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_
diff --git a/ffi/include/tvm/ffi/container/shape.h b/ffi/include/tvm/ffi/container/shape.h
new file mode 100644
index 0000000000..f6f19d9289
--- /dev/null
+++ b/ffi/include/tvm/ffi/container/shape.h
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/ffi/container/shape.h
+ * \brief Container to store shape of an NDArray.
+ */
+#ifndef TVM_FFI_CONTAINER_SHAPE_H_
+#define TVM_FFI_CONTAINER_SHAPE_H_
+
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <ostream>
+#include <vector>
+
+namespace tvm {
+namespace ffi {
+
+/*! \brief Object holding a shape tuple: `size` dimension values stored at `data`. */
+class ShapeObj : public Object {
+ public:
+  using index_type = int64_t;
+
+  const int64_t* data;  // pointer to the first dimension value
+  size_t size;          // number of dimensions
+
+  /*! \brief Get "numel": the number of elements of an array with this shape (1 if empty). */
+  int64_t Product() const {
+    int64_t product = 1;
+    for (size_t i = 0; i < this->size; ++i) {
+      product *= this->data[i];
+    }
+    return product;
+  }
+
+  static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape;
+  static constexpr const char* _type_key = StaticTypeKey::kTVMFFIShape;
+  TVM_FFI_DECLARE_STATIC_OBJECT_INFO(ShapeObj, Object);
+};
+
+namespace details {
+
+class ShapeObjStdImpl : public ShapeObj {
+ public:
+  explicit ShapeObjStdImpl(std::vector<int64_t> other) : data_(std::move(other)) {
+    this->data = data_.data();
+    this->size = static_cast<size_t>(data_.size());
+  }
+
+ private:
+  std::vector<int64_t> data_;  // owning storage; ShapeObj::data points into it
+};
+
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeEmptyShape() {
+  ObjectPtr<ShapeObj> p = make_object<ShapeObj>();
+  p->data = nullptr;
+  p->size = 0;
+  return p;
+}
+
+// Allocate a ShapeObj whose elements are stored inline, directly after the object.
+template <typename IterType>
+TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end) {
+  size_t length = static_cast<size_t>(std::distance(begin, end));
+  ObjectPtr<ShapeObj> p = make_inplace_array_object<ShapeObj, int64_t>(length);
+  static_assert(alignof(ShapeObj) % alignof(int64_t) == 0, "inline shape data misaligned");
+  static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0, "inline shape data misaligned");
+  int64_t* dest_data =
+      reinterpret_cast<int64_t*>(reinterpret_cast<char*>(p.get()) + sizeof(ShapeObj));
+  p->data = dest_data;
+  p->size = length;
+  std::copy(begin, end, dest_data);
+  return p;
+}
+
+}  // namespace details
+
+/*!
+ * \brief Reference to shape object.
+ */
+class Shape : public ObjectRef {
+ public:
+  /*! \brief The type of shape index element. */
+  using index_type = ShapeObj::index_type;
+
+  /*! \brief Default constructor, creates an empty (zero-dimensional) shape. */
+  Shape() : ObjectRef(details::MakeEmptyShape()) {}
+
+  /*!
+   * \brief Constructor from an iterator range; elements are copied inline into the object.
+   * \param begin begin of iterator
+   * \param end end of iterator
+   * \tparam IterType The type of iterator
+   */
+  template <typename IterType>
+  Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {}
+
+  /*!
+   * \brief Constructor from Array<int64_t>
+   * \param shape The Array<int64_t>
+   *
+   * \note This constructor will copy the data content.
+   */
+  Shape(Array<int64_t> shape)  // NOLINT(*)
+      : Shape(shape.begin(), shape.end()) {}
+
+  /*!
+   * \brief constructor from initializer list
+   * \param shape The initializer list
+   */
+  Shape(std::initializer_list<int64_t> shape) : Shape(shape.begin(), shape.end()) {}
+
+  /*!
+   * \brief constructor from std::vector<int64_t>
+   *
+   * \param other The vector; its storage is moved into the shape object without copying.
+   */
+  Shape(std::vector<int64_t> other)  // NOLINT(*)
+      : ObjectRef(make_object<details::ShapeObjStdImpl>(std::move(other))) {}
+
+  /*!
+   * \brief Return the data pointer
+   *
+   * \return const index_type* data pointer
+   */
+  const int64_t* data() const { return get()->data; }
+
+  /*!
+   * \brief Return the size of the shape tuple
+   *
+   * \return size_t shape tuple size
+   */
+  size_t size() const { return get()->size; }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple.
+   * \param idx The index
+   * \return the i-th element; throws IndexError if idx >= size().
+   */
+  int64_t operator[](size_t idx) const {
+    if (idx >= this->size()) {
+      TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size();
+    }
+    return this->data()[idx];
+  }
+
+  /*!
+   * \brief Immutably read i-th element from the shape tuple (bounds checked).
+   * \param idx The index
+   * \return the i-th element; throws IndexError if idx >= size().
+   */
+  int64_t at(size_t idx) const { return this->operator[](idx); }
+
+  /*! \return Whether shape tuple is empty */
+  bool empty() const { return size() == 0; }
+
+  /*! \return The first element of the shape tuple; throws IndexError if empty */
+  int64_t front() const { return this->at(0); }
+
+  /*! \return The last element of the shape tuple; throws IndexError if empty */
+  int64_t back() const { return this->at(this->size() - 1); }
+
+  /*! \return begin iterator */
+  const int64_t* begin() const { return get()->data; }
+
+  /*! \return end iterator */
+  const int64_t* end() const { return (get()->data + size()); }
+
+  /*! \return The product of the shape tuple */
+  int64_t Product() const { return get()->Product(); }
+
+  TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Shape, ObjectRef, ShapeObj);
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Shape& shape) {
+  os << '[';
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (i != 0) {
+      os << ", ";
+    }
+    os << shape[i];
+  }
+  os << ']';
+  return os;
+}
+
+// Shape uses the custom TypeTraits specialization below instead of the defaults.
+template <>
+inline constexpr bool use_default_type_traits_v<Shape> = false;
+
+// Allow auto conversion from Array<int64_t> to Shape, but not from Shape to Array<int64_t>
+template <>
+struct TypeTraits<Shape> : public ObjectRefWithFallbackTraitsBase<Shape, Array<int64_t>> {
+  static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape;
+  static TVM_FFI_INLINE Shape ConvertFallbackValue(Array<int64_t> src) { return Shape(src); }
+};
+
+}  // namespace ffi
+}  // namespace tvm
+
+#endif  // TVM_FFI_CONTAINER_SHAPE_H_
diff --git a/ffi/include/tvm/ffi/dtype.h b/ffi/include/tvm/ffi/dtype.h
index 257b6bc158..f4a7d31390 100644
--- a/ffi/include/tvm/ffi/dtype.h
+++ b/ffi/include/tvm/ffi/dtype.h
@@ -20,8 +20,6 @@
 /*!
  * \file tvm/ffi/dtype.h
  * \brief Data type handling.
- *
- * This file contains convenient methods for holding
  */
 #ifndef TVM_FFI_DTYPE_H_
 #define TVM_FFI_DTYPE_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index f8d3b7e0a9..e3cabaef3e 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -52,6 +52,9 @@ struct StaticTypeKey {
   static constexpr const char* kTVMFFIRawStr = "const char*";
   static constexpr const char* kTVMFFIByteArrayPtr = "TVMFFIByteArray*";
   static constexpr const char* kTVMFFIObjectRValueRef = "ObjectRValueRef";
+  static constexpr const char* kTVMFFIShape = "object.Shape";
+  static constexpr const char* kTVMFFIBytes = "object.Bytes";
+  static constexpr const char* kTVMFFIStr = "object.String";
 };
 
 /*!
@@ -60,32 +63,7 @@ struct StaticTypeKey {
  * \return the type key
  */
 inline std::string TypeIndexToTypeKey(int32_t type_index) {
-  switch (type_index) {
-    case TypeIndex::kTVMFFIAny:
-      return StaticTypeKey::kTVMFFIAny;
-    case TypeIndex::kTVMFFINone:
-      return StaticTypeKey::kTVMFFINone;
-    case TypeIndex::kTVMFFIBool:
-      return StaticTypeKey::kTVMFFIBool;
-    case TypeIndex::kTVMFFIInt:
-      return StaticTypeKey::kTVMFFIInt;
-    case TypeIndex::kTVMFFIFloat:
-      return StaticTypeKey::kTVMFFIFloat;
-    case TypeIndex::kTVMFFIOpaquePtr:
-      return StaticTypeKey::kTVMFFIOpaquePtr;
-    case TypeIndex::kTVMFFIDataType:
-      return StaticTypeKey::kTVMFFIDataType;
-    case TypeIndex::kTVMFFIDevice:
-      return StaticTypeKey::kTVMFFIDevice;
-    case TypeIndex::kTVMFFIRawStr:
-      return StaticTypeKey::kTVMFFIRawStr;
-    case TypeIndex::kTVMFFIObjectRValueRef:
-      return StaticTypeKey::kTVMFFIObjectRValueRef;
-    default: {
-      const TypeInfo* type_info = TVMFFIGetTypeInfo(type_index);
-      return type_info->type_key;
-    }
-  }
+  return TVMFFIGetTypeInfo(type_index)->type_key;  // NOTE(review): is kTVMFFIAny in the table?
 }
 
 namespace details {
diff --git a/ffi/include/tvm/ffi/optional.h b/ffi/include/tvm/ffi/optional.h
index 4ac9bc198b..7b3f69ef99 100644
--- a/ffi/include/tvm/ffi/optional.h
+++ b/ffi/include/tvm/ffi/optional.h
@@ -20,6 +20,7 @@
 /*!
  * \file tvm/ffi/optional.h
  * \brief Runtime Optional container types.
+ * \note Optional<T> is specialized for ObjectRef types T and uses nullptr to indicate nullopt.
  */
 #ifndef TVM_FFI_OPTIONAL_H_
 #define TVM_FFI_OPTIONAL_H_
@@ -34,6 +35,9 @@
 namespace tvm {
 namespace ffi {
 
+// Note: We place optional in tvm/ffi instead of tvm/ffi/container
+// because optional itself is an inherent core component of the FFI system.
+
 template <typename T>
 inline constexpr bool is_optional_type_v = false;
 
diff --git a/ffi/include/tvm/ffi/string.h b/ffi/include/tvm/ffi/string.h
index 3ffc358d7d..70a3053511 100644
--- a/ffi/include/tvm/ffi/string.h
+++ b/ffi/include/tvm/ffi/string.h
@@ -63,7 +63,7 @@ class BytesObjBase : public Object {
 class BytesObj : public BytesObjBase {
  public:
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes;
-  static constexpr const char* _type_key = "object.Bytes";
+  static constexpr const char* _type_key = StaticTypeKey::kTVMFFIBytes;  // still "object.Bytes"
   static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(BytesObj, Object);
 };
@@ -72,7 +72,7 @@ class BytesObj : public BytesObjBase {
 class StringObj : public BytesObjBase {
  public:
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr;
-  static constexpr const char* _type_key = "object.String";
+  static constexpr const char* _type_key = StaticTypeKey::kTVMFFIStr;  // still "object.String"
   static const constexpr bool _type_final = true;
   TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
 };
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index eeaf857115..5703127c7d 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -242,6 +242,7 @@ class TypeTable {
                               Object::_type_child_slots, Object::_type_child_slots_can_overflow,
                               -1);
     // reserve the static types
+    ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFINone, TypeIndex::kTVMFFINone);
     ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIInt, TypeIndex::kTVMFFIInt);
     ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIFloat, TypeIndex::kTVMFFIFloat);
     ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIBool, TypeIndex::kTVMFFIBool);
@@ -252,6 +253,7 @@ class TypeTable {
     ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIByteArrayPtr, TypeIndex::kTVMFFIByteArrayPtr);
     ReserveBuiltinTypeIndex(StaticTypeKey::kTVMFFIObjectRValueRef,
                             TypeIndex::kTVMFFIObjectRValueRef);
+    // No need to reserve object types here as they will be registered.
   }
 
   void ReserveBuiltinTypeIndex(const char* type_key, int32_t static_type_index) {
diff --git a/ffi/tests/cpp/test_shape.cc b/ffi/tests/cpp/test_shape.cc
new file mode 100644
index 0000000000..c6d8d5dbd8
--- /dev/null
+++ b/ffi/tests/cpp/test_shape.cc
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <gtest/gtest.h>
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/shape.h>
+
+namespace {
+
+using namespace tvm::ffi;
+
+TEST(Shape, Basic) {
+  Shape shape = Shape({1, 2, 3});
+  EXPECT_EQ(shape.size(), 3U);
+  EXPECT_EQ(shape[0], 1);
+  EXPECT_EQ(shape[1], 2);
+  EXPECT_EQ(shape[2], 3);
+
+  Shape shape2 = Shape(Array<int64_t>({4, 5, 6, 7}));
+  EXPECT_EQ(shape2.size(), 4U);
+  EXPECT_EQ(shape2[0], 4);
+  EXPECT_EQ(shape2[1], 5);
+  EXPECT_EQ(shape2[2], 6);
+  EXPECT_EQ(shape2[3], 7);
+
+  std::vector<int64_t> vec = {8, 9, 10};
+  Shape shape3 = Shape(std::move(vec));
+  EXPECT_EQ(shape3.size(), 3U);
+  EXPECT_EQ(shape3[0], 8);
+  EXPECT_EQ(shape3[1], 9);
+  EXPECT_EQ(shape3[2], 10);
+  EXPECT_EQ(shape3.Product(), 8 * 9 * 10);
+
+  Shape shape4 = Shape();
+  EXPECT_EQ(shape4.size(), 0U);
+  EXPECT_EQ(shape4.Product(), 1);
+}
+
+TEST(Shape, AnyConvert) {
+  Shape shape0 = Shape({1, 2, 3});
+  Any any0 = shape0;
+
+  Shape shape1 = any0;
+  EXPECT_EQ(shape1.size(), 3U);
+  EXPECT_EQ(shape1[0], 1);
+  EXPECT_EQ(shape1[1], 2);
+  EXPECT_EQ(shape1[2], 3);
+
+  Array<Any> arr({1, 2});
+  AnyView any_view0 = arr;
+  Shape shape2 = any_view0;  // Array -> Shape via the TypeTraits<Shape> fallback conversion
+  EXPECT_EQ(shape2.size(), 2U);
+  EXPECT_EQ(shape2[0], 1);
+  EXPECT_EQ(shape2[1], 2);
+}
+
+}  // namespace

Reply via email to