This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch tensor-abi in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
commit c708337230aa58c4695731fd9a39628b14d837ee Author: tqchen <[email protected]> AuthorDate: Fri Sep 26 08:10:51 2025 -0400 ABI hack --- include/tvm/ffi/container/tensor.h | 35 +++++++++++++++++++++++------------ src/ffi/extra/structural_hash.cc | 2 +- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h index 9f698de..c4bb050 100644 --- a/include/tvm/ffi/container/tensor.h +++ b/include/tvm/ffi/container/tensor.h @@ -120,8 +120,9 @@ inline size_t GetDataSize(const DLTensor& arr) { } /*! \brief An object representing a Tensor. */ -class TensorObj : public Object, public DLTensor { +class TensorObj : public Object { public: + DLTensor *dl_tensor_ptr; /// \cond Doxygen_Suppress static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor; TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object); @@ -142,7 +143,7 @@ class TensorObj : public Object, public DLTensor { DLManagedTensor* ToDLPack() const { TensorObj* self = const_cast<TensorObj*>(this); DLManagedTensor* ret = new DLManagedTensor(); - ret->dl_tensor = *static_cast<DLTensor*>(self); + ret->dl_tensor = *static_cast<DLTensor*>(self->dl_tensor_ptr); ret->manager_ctx = self; ret->deleter = DLManagedTensorDeleter; details::ObjectUnsafe::IncRefObjectHandle(self); @@ -165,7 +166,7 @@ class TensorObj : public Object, public DLTensor { DLManagedTensorVersioned* ret = new DLManagedTensorVersioned(); ret->version.major = DLPACK_MAJOR_VERSION; ret->version.minor = DLPACK_MINOR_VERSION; - ret->dl_tensor = *static_cast<DLTensor*>(from); + ret->dl_tensor = *static_cast<DLTensor*>(from->dl_tensor_ptr); ret->manager_ctx = from; ret->deleter = EmbeddedDLManagedTensorVersionedDeleter; ret->flags = 0; @@ -227,7 +228,7 @@ namespace details { * The underlying allocator needs to be implemented by user. */ template <typename TNDAlloc> -class TensorObjFromNDAlloc : public TensorObj { +class TensorObjFromNDAlloc : public TensorObj, private DLTensor { public: template <typename... ExtraArgs> TensorObjFromNDAlloc(TNDAlloc alloc, ffi::Shape shape, DLDataType dtype, DLDevice device, @@ -243,6 +244,7 @@ class TensorObjFromNDAlloc : public TensorObj { this->shape_data_ = std::move(shape); this->strides_data_ = std::move(strides); alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...); + this->dl_tensor_ptr = static_cast<DLTensor*>(this); } ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); } @@ -256,10 +258,10 @@ template <typename TDLPackManagedTensor> class TensorObjFromDLPack : public TensorObj { public: explicit TensorObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) { - *static_cast<DLTensor*>(this) = tensor_->dl_tensor; + this->dl_tensor_ptr = &(tensor_->dl_tensor); if (tensor_->dl_tensor.strides == nullptr) { Shape strides = Shape::StridesFromShape(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim); - this->strides = const_cast<int64_t*>(strides.data()); + dl_tensor_ptr->strides = const_cast<int64_t*>(strides.data()); this->strides_data_ = std::move(strides); } } @@ -292,7 +294,7 @@ class Tensor : public ObjectRef { tvm::ffi::Shape shape() const { TensorObj* obj = get_mutable(); if (!obj->shape_data_.has_value()) { - obj->shape_data_ = tvm::ffi::Shape(obj->shape, obj->shape + obj->ndim); + obj->shape_data_ = tvm::ffi::Shape(obj->dl_tensor_ptr->shape, obj->dl_tensor_ptr->shape + obj->dl_tensor_ptr->ndim); } return *(obj->shape_data_); } @@ -302,9 +304,9 @@ class Tensor : public ObjectRef { */ tvm::ffi::Shape strides() const { TensorObj* obj = get_mutable(); - TVM_FFI_ICHECK(obj->strides != nullptr); + TVM_FFI_ICHECK(obj->dl_tensor_ptr->strides != nullptr); if (!obj->strides_data_.has_value()) { - obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim); + obj->strides_data_ = tvm::ffi::Shape(obj->dl_tensor_ptr->strides, obj->dl_tensor_ptr->strides + obj->dl_tensor_ptr->ndim); } return *(obj->strides_data_); } @@ -317,13 +319,13 @@ class Tensor : public ObjectRef { * \brief Check if the Tensor is contiguous. * \return True if the Tensor is contiguous, false otherwise. */ - bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } + bool IsContiguous() const { return tvm::ffi::IsContiguous(*operator->()); } /*! * \brief Check if the Tensor data is aligned to the given alignment. * \param alignment The alignment to check. * \return True if the Tensor data is aligned to the given alignment, false otherwise. */ - bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); } + bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*operator->(), alignment); } /*! * \brief Create a Tensor from a NDAllocator. * \param alloc The NDAllocator. @@ -449,8 +451,17 @@ class Tensor : public ObjectRef { */ DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } + Tensor() = default; + explicit Tensor(::tvm::ffi::ObjectPtr<TensorObj> n) : ObjectRef(n) {} + explicit Tensor(::tvm::ffi::UnsafeInit tag) : ObjectRef(tag) {} + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Tensor) + const DLTensor* operator->() const { return static_cast<TensorObj*>(data_.get())->dl_tensor_ptr; } + const TensorObj* get() const { return static_cast<const TensorObj*>(data_.get()); } + [[maybe_unused]] static constexpr bool _type_is_nullable = true; + using ContainerType = TensorObj; + /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj); + //TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj); /// \endcond protected: diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc index f6463af..b3fa140 100644 --- a/src/ffi/extra/structural_hash.cc +++ b/src/ffi/extra/structural_hash.cc @@ -268,7 +268,7 @@ class StructuralHashHandler { } uint64_t HashTensor(Tensor tensor) { - uint64_t hash_value = details::StableHashCombine(tensor->GetTypeKeyHash(), tensor->ndim); + uint64_t hash_value = details::StableHashCombine(tensor.get()->GetTypeKeyHash(), tensor->ndim); for (int i = 0; i < tensor->ndim; ++i) { hash_value = details::StableHashCombine(hash_value, tensor->shape[i]); }
