This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch tensor-abi
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git

commit c708337230aa58c4695731fd9a39628b14d837ee
Author: tqchen <[email protected]>
AuthorDate: Fri Sep 26 08:10:51 2025 -0400

    ABI hack
---
 include/tvm/ffi/container/tensor.h | 35 +++++++++++++++++++++++------------
 src/ffi/extra/structural_hash.cc   |  2 +-
 2 files changed, 24 insertions(+), 13 deletions(-)

diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index 9f698de..c4bb050 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -120,8 +120,9 @@ inline size_t GetDataSize(const DLTensor& arr) {
 }
 
 /*! \brief An object representing a Tensor. */
-class TensorObj : public Object, public DLTensor {
+class TensorObj : public Object {
  public:
+  DLTensor *dl_tensor_ptr;
   /// \cond Doxygen_Suppress
   static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor;
   TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, 
Object);
@@ -142,7 +143,7 @@ class TensorObj : public Object, public DLTensor {
   DLManagedTensor* ToDLPack() const {
     TensorObj* self = const_cast<TensorObj*>(this);
     DLManagedTensor* ret = new DLManagedTensor();
-    ret->dl_tensor = *static_cast<DLTensor*>(self);
+    ret->dl_tensor = *static_cast<DLTensor*>(self->dl_tensor_ptr);
     ret->manager_ctx = self;
     ret->deleter = DLManagedTensorDeleter;
     details::ObjectUnsafe::IncRefObjectHandle(self);
@@ -165,7 +166,7 @@ class TensorObj : public Object, public DLTensor {
       DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
       ret->version.major = DLPACK_MAJOR_VERSION;
       ret->version.minor = DLPACK_MINOR_VERSION;
-      ret->dl_tensor = *static_cast<DLTensor*>(from);
+      ret->dl_tensor = *static_cast<DLTensor*>(from->dl_tensor_ptr);
       ret->manager_ctx = from;
       ret->deleter = EmbeddedDLManagedTensorVersionedDeleter;
       ret->flags = 0;
@@ -227,7 +228,7 @@ namespace details {
  * The underlying allocator needs to be implemented by user.
  */
 template <typename TNDAlloc>
-class TensorObjFromNDAlloc : public TensorObj {
+class TensorObjFromNDAlloc : public TensorObj, private DLTensor {
  public:
   template <typename... ExtraArgs>
   TensorObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, 
DLDevice device,
@@ -243,6 +244,7 @@ class TensorObjFromNDAlloc : public TensorObj {
     this->shape_data_ = std::move(shape);
     this->strides_data_ = std::move(strides);
     alloc_.AllocData(static_cast<DLTensor*>(this), 
std::forward<ExtraArgs>(extra_args)...);
+    this->dl_tensor_ptr = static_cast<DLTensor*>(this);
   }
 
   ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); }
@@ -256,10 +258,10 @@ template <typename TDLPackManagedTensor>
 class TensorObjFromDLPack : public TensorObj {
  public:
   explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) 
{
-    *static_cast<DLTensor*>(this) = tensor_->dl_tensor;
+    this->dl_tensor_ptr = &(tensor_->dl_tensor);
     if (tensor_->dl_tensor.strides == nullptr) {
       Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, 
tensor_->dl_tensor.ndim);
-      this->strides = const_cast<int64_t*>(strides.data());
+      dl_tensor_ptr->strides = const_cast<int64_t*>(strides.data());
       this->strides_data_ = std::move(strides);
     }
   }
@@ -292,7 +294,7 @@ class Tensor : public ObjectRef {
   tvm::ffi::Shape shape() const {
     TensorObj* obj = get_mutable();
     if (!obj->shape_data_.has_value()) {
-      obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim);
+      obj->shape_data_ = tvm::ffi::Shape(obj->dl_tensor_ptr->shape, 
obj->dl_tensor_ptr->shape + obj->dl_tensor_ptr->ndim);
     }
     return *(obj->shape_data_);
   }
@@ -302,9 +304,9 @@ class Tensor : public ObjectRef {
    */
   tvm::ffi::Shape strides() const {
     TensorObj* obj = get_mutable();
-    TVM_FFI_ICHECK(obj->strides != nullptr);
+    TVM_FFI_ICHECK(obj->dl_tensor_ptr->strides != nullptr);
     if (!obj->strides_data_.has_value()) {
-      obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + 
obj->ndim);
+      obj->strides_data_ = tvm::ffi::Shape(obj->dl_tensor_ptr->strides, 
obj->dl_tensor_ptr->strides + obj->dl_tensor_ptr->ndim);
     }
     return *(obj->strides_data_);
   }
@@ -317,13 +319,13 @@ class Tensor : public ObjectRef {
    * \brief Check if the Tensor is contiguous.
    * \return True if the Tensor is contiguous, false otherwise.
    */
-  bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
+  bool IsContiguous() const { return tvm::ffi::IsContiguous(*operator->()); }
   /*!
    * \brief Check if the Tensor data is aligned to the given alignment.
    * \param alignment The alignment to check.
    * \return True if the Tensor data is aligned to the given alignment, false 
otherwise.
    */
-  bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), 
alignment); }
+  bool IsAligned(size_t alignment) const { return 
tvm::ffi::IsAligned(*operator->(), alignment); }
   /*!
    * \brief Create a Tensor from a NDAllocator.
    * \param alloc The NDAllocator.
@@ -449,8 +451,17 @@ class Tensor : public ObjectRef {
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const { return 
get_mutable()->ToDLPackVersioned(); }
 
+  Tensor() = default;
+  explicit Tensor(::tvm::ffi::ObjectPtr<TensorObj> n) : ObjectRef(n) {}
+  explicit Tensor(::tvm::ffi::UnsafeInit tag) : ObjectRef(tag) {}
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Tensor)
+  const DLTensor* operator->() const { return 
static_cast<TensorObj*>(data_.get())->dl_tensor_ptr; }
+  const TensorObj* get() const { return static_cast<const 
TensorObj*>(data_.get()); }
+  [[maybe_unused]] static constexpr bool _type_is_nullable = true;
+  using ContainerType = TensorObj;
+
   /// \cond Doxygen_Suppress
-  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj);
+  //TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj);
   /// \endcond
 
  protected:
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index f6463af..b3fa140 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -268,7 +268,7 @@ class StructuralHashHandler {
   }
 
   uint64_t HashTensor(Tensor tensor) {
-    uint64_t hash_value = details::StableHashCombine(tensor->GetTypeKeyHash(), 
tensor->ndim);
+    uint64_t hash_value = 
details::StableHashCombine(tensor.get()->GetTypeKeyHash(), tensor->ndim);
     for (int i = 0; i < tensor->ndim; ++i) {
       hash_value = details::StableHashCombine(hash_value, tensor->shape[i]);
     }

Reply via email to