This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 46eac564a5 [FFI][ABI] Introduce weak rc support (#18259)
46eac564a5 is described below
commit 46eac564a59fcb66277f2b32bfdc1ddea95cd07c
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Sep 1 08:04:56 2025 -0400
[FFI][ABI] Introduce weak rc support (#18259)
This PR adds weak ref counter support to the FFI ABI.
Weak rc is useful when we want to break cyclic dependencies.
- When the strong rc goes to zero, we call the destructor of the object but
do not free the memory.
- When both the strong and weak rc go to zero, we free the memory.
The weak rc mechanism is useful for breaking cyclic dependencies between
objects: a weak rc keeps the memory block alive even after the destructor
has been called.
As of now, because we deliberately avoid cycles in the codebase, we do not
have a strong use case for weak rc.
However, weak rc is common practice: std::shared_ptr, Rust's Rc, and torch's
c10::intrusive_ptr all support it. It is therefore better to make sure the ABI
is future-compatible with such use cases before we freeze it.
This PR implements weak rc as a u32 counter and strong rc as a u64 counter,
with the following design considerations:
- Weak rc is very rarely used, so u32 is sufficient.
- Keeping weak rc in u32 allows us to keep the object header size at 24 bytes,
saving an extra 8 bytes (considering alignment).
We also need to update the deleter to take flags that distinguish strong and
weak deletion events. The implementation optimizes the common case where both
the strong and weak counts go to 0 at the same time by calling the deleter
once with both flags set.
---
ffi/include/tvm/ffi/c_api.h | 65 ++++-
ffi/include/tvm/ffi/memory.h | 46 ++--
ffi/include/tvm/ffi/object.h | 261 +++++++++++++++++++--
ffi/include/tvm/ffi/type_traits.h | 2 +-
ffi/pyproject.toml | 2 +-
ffi/python/tvm_ffi/cython/base.pxi | 2 +-
ffi/python/tvm_ffi/cython/dtype.pxi | 2 +-
ffi/python/tvm_ffi/cython/object.pxi | 2 +-
ffi/src/ffi/object.cc | 8 +-
ffi/tests/cpp/test_c_ffi_abi.cc | 2 +-
ffi/tests/cpp/test_object.cc | 119 ++++++++++
jvm/native/src/main/native/jni_helper_func.h | 2 +-
.../src/main/native/org_apache_tvm_native_c_api.cc | 2 +-
src/tir/transforms/make_packed_api.cc | 4 +-
web/src/ctypes.ts | 6 +-
web/src/runtime.ts | 8 +-
16 files changed, 475 insertions(+), 58 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index f099898b15..b4f59526a9 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -156,6 +156,36 @@ typedef enum {
/*! \brief Handle to Object from C API's pov */
typedef void* TVMFFIObjectHandle;
+/*!
+ * \brief bitmask of the object deleter flag.
+ */
+#ifdef __cplusplus
+enum TVMFFIObjectDeleterFlagBitMask : int32_t {
+#else
+typedef enum {
+#endif
+ /*!
+ * \brief deleter action when strong reference count becomes zero.
+ * Need to call destructor of the object but not free the memory block.
+ */
+ kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0,
+ /*!
+ * \brief deleter action when weak reference count becomes zero.
+ * Need to free the memory block.
+ */
+ kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1,
+ /*!
+ * \brief deleter action when both strong and weak reference counts become
zero.
+ * \note This is the most common case.
+ */
+ kTVMFFIObjectDeleterFlagBitMaskBoth =
+ (kTVMFFIObjectDeleterFlagBitMaskStrong |
kTVMFFIObjectDeleterFlagBitMaskWeak),
+#ifdef __cplusplus
+};
+#else
+} TVMFFIObjectDeleterFlagBitMask;
+#endif
+
/*!
* \brief C-based type of all FFI object header that allocates on heap.
* \note TVMFFIObject and TVMFFIAny share the common type_index header
@@ -166,11 +196,22 @@ typedef struct TVMFFIObject {
* \note The type index of Object and Any are shared in FFI.
*/
int32_t type_index;
- /*! \brief Reference counter of the object. */
- int32_t ref_counter;
+ /*!
+ * \brief Weak reference counter of the object, for compatiblity with
weak_ptr design.
+ * \note Use u32 to ensure that overall object stays within 24-byte
boundary, usually
+ * manipulation of weak counter is less common than strong counter.
+ */
+ uint32_t weak_ref_count;
+ /*! \brief Strong reference counter of the object. */
+ uint64_t strong_ref_count;
union {
- /*! \brief Deleter to be invoked when reference counter goes to zero. */
- void (*deleter)(struct TVMFFIObject* self);
+ /*!
+ * \brief Deleter to be invoked when strong reference counter goes to zero.
+ * \param self The self object handle.
+ * \param flags The flags to indicate deletion behavior.
+ * \sa TVMFFIObjectDeleterFlagBitMask
+ */
+ void (*deleter)(struct TVMFFIObject* self, int flags);
/*!
* \brief auxilary field to TVMFFIObject is always 8 bytes aligned.
* \note This helps us to ensure cross platform compatibility.
@@ -307,13 +348,19 @@ typedef struct {
// Section: Basic object API
//------------------------------------------------------------
/*!
- * \brief Free an object handle by decreasing reference
+ * \brief Increas the strong reference count of an object handle
+ * \param obj The object handle.
+ * \note Internally we increase the reference counter of the object.
+ * \return 0 when success, nonzero when failure happens
+ */
+TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj);
+
+/*!
+ * \brief Free an object handle by decreasing strong reference
* \param obj The object handle.
- * \note Internally we decrease the reference counter of the object.
- * The object will be freed when every reference to the object are
removed.
* \return 0 when success, nonzero when failure happens
*/
-TVM_FFI_DLL int TVMFFIObjectFree(TVMFFIObjectHandle obj);
+TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj);
/*!
* \brief Convert type key to type index.
@@ -470,7 +517,7 @@ TVM_FFI_DLL int TVMFFIDataTypeFromString(const
TVMFFIByteArray* str, DLDataType*
* \param dtype The DLDataType to convert.
* \param out The output string.
* \return 0 when success, nonzero when failure happens
-* \note out is a String object that needs to be freed by the caller via
TVMFFIObjectFree.
+* \note out is a String object that needs to be freed by the caller via
TVMFFIObjectDecRef.
The content of string can be accessed via TVMFFIObjectGetByteArrayPtr.
* \note The input dtype is a pointer to the DLDataType to avoid ABI
compatibility issues.
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index 02537df79c..533d000427 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -33,7 +33,7 @@ namespace tvm {
namespace ffi {
/*! \brief Deleter function for obeject */
-typedef void (*FObjectDeleter)(TVMFFIObject* obj);
+typedef void (*FObjectDeleter)(TVMFFIObject* obj, int flags);
/*!
* \brief Allocate an object using default allocator.
@@ -75,7 +75,8 @@ class ObjAllocatorBase {
static_assert(std::is_base_of<Object, T>::value, "make can only be used to
create Object");
T* ptr = Handler::New(static_cast<Derived*>(this),
std::forward<Args>(args)...);
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
- ffi_ptr->ref_counter = 1;
+ ffi_ptr->strong_ref_count = 1;
+ ffi_ptr->weak_ref_count = 1;
ffi_ptr->type_index = T::RuntimeTypeIndex();
ffi_ptr->deleter = Handler::Deleter();
return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr);
@@ -96,7 +97,8 @@ class ObjAllocatorBase {
ArrayType* ptr =
Handler::New(static_cast<Derived*>(this), num_elems,
std::forward<Args>(args)...);
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
- ffi_ptr->ref_counter = 1;
+ ffi_ptr->strong_ref_count = 1;
+ ffi_ptr->weak_ref_count = 1;
ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
ffi_ptr->deleter = Handler::Deleter();
return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr);
@@ -136,14 +138,18 @@ class SimpleObjAllocator : public
ObjAllocatorBase<SimpleObjAllocator> {
static FObjectDeleter Deleter() { return Deleter_; }
private:
- static void Deleter_(TVMFFIObject* objptr) {
+ static void Deleter_(TVMFFIObject* objptr, int flags) {
T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(objptr);
- // It is important to do tptr->T::~T(),
- // so that we explicitly call the specific destructor
- // instead of tptr->~T(), which could mean the intention
- // call a virtual destructor(which may not be available and is not
required).
- tptr->T::~T();
- delete reinterpret_cast<StorageType*>(tptr);
+ if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
+ // It is important to do tptr->T::~T(),
+ // so that we explicitly call the specific destructor
+ // instead of tptr->~T(), which could mean the intention
+ // call a virtual destructor(which may not be available and is not
required).
+ tptr->T::~T();
+ }
+ if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) {
+ delete reinterpret_cast<StorageType*>(tptr);
+ }
}
};
@@ -182,15 +188,19 @@ class SimpleObjAllocator : public
ObjAllocatorBase<SimpleObjAllocator> {
static FObjectDeleter Deleter() { return Deleter_; }
private:
- static void Deleter_(TVMFFIObject* objptr) {
+ static void Deleter_(TVMFFIObject* objptr, int flags) {
ArrayType* tptr =
details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(objptr);
- // It is important to do tptr->ArrayType::~ArrayType(),
- // so that we explicitly call the specific destructor
- // instead of tptr->~ArrayType(), which could mean the intention
- // call a virtual destructor(which may not be available and is not
required).
- tptr->ArrayType::~ArrayType();
- StorageType* p = reinterpret_cast<StorageType*>(tptr);
- delete[] p;
+ if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) {
+ // It is important to do tptr->ArrayType::~ArrayType(),
+ // so that we explicitly call the specific destructor
+ // instead of tptr->~ArrayType(), which could mean the intention
+ // call a virtual destructor(which may not be available and is not
required).
+ tptr->ArrayType::~ArrayType();
+ }
+ if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) {
+ StorageType* p = reinterpret_cast<StorageType*>(tptr);
+ delete[] p;
+ }
}
};
};
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index cf282a6e27..cc5ee8d945 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -143,7 +143,8 @@ class Object {
public:
Object() {
- header_.ref_counter = 0;
+ header_.strong_ref_count = 0;
+ header_.weak_ref_count = 0;
header_.deleter = nullptr;
}
/*!
@@ -197,9 +198,9 @@ class Object {
int32_t use_count() const {
// only need relaxed load of counters
#ifdef _MSC_VER
- return (reinterpret_cast<const volatile long*>(&header_.ref_counter))[0];
// NOLINT(*)
+ return (reinterpret_cast<const volatile
__int64*>(&header_.strong_ref_count))[0]; // NOLINT(*)
#else
- return __atomic_load_n(&(header_.ref_counter), __ATOMIC_RELAXED);
+ return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED);
#endif
}
@@ -230,33 +231,121 @@ class Object {
static int32_t _GetOrAllocRuntimeTypeIndex() { return
TypeIndex::kTVMFFIObject; }
private:
- /*! \brief increase reference count */
+ /*! \brief increase strong reference count, the caller must already hold a
strong reference */
void IncRef() {
#ifdef _MSC_VER
- _InterlockedIncrement(reinterpret_cast<volatile
long*>(&header_.ref_counter)); // NOLINT(*)
+ _InterlockedIncrement64(
+ reinterpret_cast<volatile __int64*>(&header_.strong_ref_count)); //
NOLINT(*)
#else
- __atomic_fetch_add(&(header_.ref_counter), 1, __ATOMIC_RELAXED);
+ __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED);
+#endif
+ }
+ /*!
+ * \brief Try to lock the object to increase the strong reference count,
+ * the caller must already hold a strong reference.
+ * \return whether the lock call is successful and object is still alive.
+ */
+ bool TryPromoteWeakPtr() {
+#ifdef _MSC_VER
+ uint64_t old_count =
+ (reinterpret_cast<const volatile
__int64*>(&header_.strong_ref_count))[0]; // NOLINT(*)
+ while (old_count > 0) {
+ uint64_t new_count = old_count + 1;
+ uint64_t old_count_loaded = _InterlockedCompareExchange64(
+ reinterpret_cast<volatile __int64*>(&header_.strong_ref_count),
new_count, old_count);
+ if (old_count == old_count_loaded) {
+ return true;
+ }
+ old_count = old_count_loaded;
+ }
+ return false;
+#else
+ uint64_t old_count = __atomic_load_n(&(header_.strong_ref_count),
__ATOMIC_RELAXED);
+ while (old_count > 0) {
+ // must do CAS to ensure that we are the only one that increases the
reference count
+ // avoid condition when two threads tries to promote weak to strong at
same time
+ // or when strong deletion happens between the load and the CAS
+ uint64_t new_count = old_count + 1;
+ if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count,
new_count, true,
+ __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) {
+ return true;
+ }
+ }
+ return false;
+#endif
+ }
+
+ /*! \brief increase weak reference count */
+ void IncWeakRef() {
+#ifdef _MSC_VER
+ _InterlockedIncrement(reinterpret_cast<volatile
long*>(&header_.weak_ref_count)); // NOLINT(*)
+#else
+ __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED);
#endif
}
- /*! \brief decrease reference count and delete the object */
+ /*! \brief decrease strong reference count and delete the object */
void DecRef() {
#ifdef _MSC_VER
- if (_InterlockedDecrement( //
- reinterpret_cast<volatile long*>(&header_.ref_counter)) == 0) {
// NOLINT(*)
+ // use simpler impl in windows to ensure correctness
+ if (_InterlockedDecrement64(
//
+ reinterpret_cast<volatile __int64*>(&header_.strong_ref_count)) ==
0) { // NOLINT(*)
// full barrrier is implicit in InterlockedDecrement
if (header_.deleter != nullptr) {
- header_.deleter(&(this->header_));
+ header_.deleter(&(this->header_),
kTVMFFIObjectDeleterFlagBitMaskStrong);
+ }
+ if (_InterlockedDecrement(
//
+ reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0)
{ // NOLINT(*)
+ if (header_.deleter != nullptr) {
+ header_.deleter(&(this->header_),
kTVMFFIObjectDeleterFlagBitMaskWeak);
+ }
}
}
#else
// first do a release, note we only need to acquire for deleter
- if (__atomic_fetch_sub(&(header_.ref_counter), 1, __ATOMIC_RELEASE) == 1) {
- // only acquire when we need to call deleter
- // in this case we need to ensure all previous writes are visible
+ if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE)
== 1) {
+ if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) {
+ // common case, we need to delete both the object and the memory block
+ // only acquire when we need to call deleter
+ __atomic_thread_fence(__ATOMIC_ACQUIRE);
+ if (header_.deleter != nullptr) {
+ // call deleter once
+ header_.deleter(&(this->header_),
kTVMFFIObjectDeleterFlagBitMaskBoth);
+ }
+ } else {
+ // Slower path: there is still a weak reference left
+ __atomic_thread_fence(__ATOMIC_ACQUIRE);
+ // call destructor first, then decrease weak reference count
+ if (header_.deleter != nullptr) {
+ header_.deleter(&(this->header_),
kTVMFFIObjectDeleterFlagBitMaskStrong);
+ }
+ // now decrease weak reference count
+ if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE)
== 1) {
+ __atomic_thread_fence(__ATOMIC_ACQUIRE);
+ if (header_.deleter != nullptr) {
+ header_.deleter(&(this->header_),
kTVMFFIObjectDeleterFlagBitMaskWeak);
+ }
+ }
+ }
+ }
+#endif
+ }
+
+ /*! \brief decrease weak reference count */
+ void DecWeakRef() {
+#ifdef _MSC_VER
+ if (_InterlockedDecrement(
//
+ reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0) {
// NOLINT(*)
+ if (header_.deleter != nullptr) {
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
+ }
+ }
+#else
+ // now decrease weak reference count
+ if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) ==
1) {
__atomic_thread_fence(__ATOMIC_ACQUIRE);
if (header_.deleter != nullptr) {
- header_.deleter(&(this->header_));
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
}
}
#endif
@@ -265,6 +354,8 @@ class Object {
// friend classes
template <typename>
friend class ObjectPtr;
+ template <typename>
+ friend class WeakObjectPtr;
friend struct tvm::ffi::details::ObjectUnsafe;
};
@@ -402,6 +493,148 @@ class ObjectPtr {
friend struct ObjectPtrHash;
template <typename>
friend class ObjectPtr;
+ template <typename>
+ friend class WeakObjectPtr;
+ friend struct tvm::ffi::details::ObjectUnsafe;
+};
+
+/*!
+ * \brief A custom smart pointer for Object.
+ * \tparam T the content data type.
+ * \sa make_object
+ */
+template <typename T>
+class WeakObjectPtr {
+ public:
+ /*! \brief default constructor */
+ WeakObjectPtr() {}
+ /*! \brief default constructor */
+ WeakObjectPtr(std::nullptr_t) {} // NOLINT(*)
+ /*!
+ * \brief copy constructor
+ * \param other The value to be moved
+ */
+ WeakObjectPtr(const WeakObjectPtr<T>& other) // NOLINT(*)
+ : WeakObjectPtr(other.data_) {}
+
+ /*!
+ * \brief copy constructor
+ * \param other The value to be moved
+ */
+ WeakObjectPtr(const ObjectPtr<T>& other) // NOLINT(*)
+ : WeakObjectPtr(other.get()) {}
+ /*!
+ * \brief copy constructor
+ * \param other The value to be moved
+ */
+ template <typename U>
+ WeakObjectPtr(const WeakObjectPtr<U>& other) // NOLINT(*)
+ : WeakObjectPtr(other.data_) {
+ static_assert(std::is_base_of<T, U>::value,
+ "can only assign of child class ObjectPtr to parent");
+ }
+ /*!
+ * \brief copy constructor
+ * \param other The value to be moved
+ */
+ template <typename U>
+ WeakObjectPtr(const ObjectPtr<U>& other) // NOLINT(*)
+ : WeakObjectPtr(other.data_) {
+ static_assert(std::is_base_of<T, U>::value,
+ "can only assign of child class ObjectPtr to parent");
+ }
+ /*!
+ * \brief move constructor
+ * \param other The value to be moved
+ */
+ WeakObjectPtr(WeakObjectPtr<T>&& other) // NOLINT(*)
+ : data_(other.data_) {
+ other.data_ = nullptr;
+ }
+ /*!
+ * \brief move constructor
+ * \param other The value to be moved
+ */
+ template <typename Y>
+ WeakObjectPtr(WeakObjectPtr<Y>&& other) // NOLINT(*)
+ : data_(other.data_) {
+ static_assert(std::is_base_of<T, Y>::value,
+ "can only assign of child class ObjectPtr to parent");
+ other.data_ = nullptr;
+ }
+ /*! \brief destructor */
+ ~WeakObjectPtr() { this->reset(); }
+ /*!
+ * \brief Swap this array with another Object
+ * \param other The other Object
+ */
+ void swap(WeakObjectPtr<T>& other) { // NOLINT(*)
+ std::swap(data_, other.data_);
+ }
+
+ /*!
+ * \brief copy assignment
+ * \param other The value to be assigned.
+ * \return reference to self.
+ */
+ WeakObjectPtr<T>& operator=(const WeakObjectPtr<T>& other) { // NOLINT(*)
+ // takes in plane operator to enable copy elison.
+ // copy-and-swap idiom
+ WeakObjectPtr(other).swap(*this); // NOLINT(*)
+ return *this;
+ }
+ /*!
+ * \brief move assignment
+ * \param other The value to be assigned.
+ * \return reference to self.
+ */
+ WeakObjectPtr<T>& operator=(WeakObjectPtr<T>&& other) { // NOLINT(*)
+ // copy-and-swap idiom
+ WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*)
+ return *this;
+ }
+
+ /*! \return The internal object pointer if the object is still alive,
otherwise nullptr */
+ ObjectPtr<T> lock() const {
+ if (data_ != nullptr && data_->TryPromoteWeakPtr()) {
+ ObjectPtr<T> ret;
+ // we already increase the reference count, so we don't need to do it
again
+ ret.data_ = data_;
+ return ret;
+ }
+ return nullptr;
+ }
+
+ /*! \brief reset the content of ptr to be nullptr */
+ void reset() {
+ if (data_ != nullptr) {
+ data_->DecWeakRef();
+ data_ = nullptr;
+ }
+ }
+
+ /*! \return The use count of the ptr, for debug purposes */
+ int use_count() const { return data_ != nullptr ? data_->use_count() : 0; }
+
+ /*! \return whether the pointer is nullptr */
+ bool expired() const { return data_ == nullptr || data_->use_count() == 0; }
+
+ private:
+ /*! \brief internal pointer field */
+ Object* data_{nullptr};
+
+ /*!
+ * \brief constructor from Object
+ * \param data The data pointer
+ */
+ explicit WeakObjectPtr(Object* data) : data_(data) {
+ if (data_ != nullptr) {
+ data_->IncWeakRef();
+ }
+ }
+
+ template <typename>
+ friend class WeakObjectPtr;
friend struct tvm::ffi::details::ObjectUnsafe;
};
diff --git a/ffi/include/tvm/ffi/type_traits.h
b/ffi/include/tvm/ffi/type_traits.h
index b019935a6c..9cdb2b9338 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -472,7 +472,7 @@ struct TypeTraits<DLTensor*> : public TypeTraitsBase {
} else if (src->type_index == TypeIndex::kTVMFFINDArray) {
// Conversion from NDArray pointer to DLTensor
// based on the assumption that NDArray always follows the TVMFFIObject
header
- static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 8
bytes");
+ static_assert(sizeof(TVMFFIObject) == 24);
return reinterpret_cast<DLTensor*>(reinterpret_cast<char*>(src->v_obj) +
sizeof(TVMFFIObject));
}
diff --git a/ffi/pyproject.toml b/ffi/pyproject.toml
index 8ed9e275e2..083a60fc36 100644
--- a/ffi/pyproject.toml
+++ b/ffi/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0a5"
+version = "0.1.0a6"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/ffi/python/tvm_ffi/cython/base.pxi
b/ffi/python/tvm_ffi/cython/base.pxi
index 14b3d97f52..4a47efd773 100644
--- a/ffi/python/tvm_ffi/cython/base.pxi
+++ b/ffi/python/tvm_ffi/cython/base.pxi
@@ -171,7 +171,7 @@ cdef extern from "tvm/ffi/c_api.h":
const TVMFFIMethodInfo* methods
const TVMFFITypeMetadata* metadata
- int TVMFFIObjectFree(TVMFFIObjectHandle obj) nogil
+ int TVMFFIObjectDecRef(TVMFFIObjectHandle obj) nogil
int TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) nogil
int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny* args, int32_t
num_args,
TVMFFIAny* result) nogil
diff --git a/ffi/python/tvm_ffi/cython/dtype.pxi
b/ffi/python/tvm_ffi/cython/dtype.pxi
index 279b17f8c8..d9e20b77f3 100644
--- a/ffi/python/tvm_ffi/cython/dtype.pxi
+++ b/ffi/python/tvm_ffi/cython/dtype.pxi
@@ -104,7 +104,7 @@ cdef class DataType:
bytes_ptr = TVMFFIBytesGetByteArrayPtr(temp_any.v_obj)
res = py_str(PyBytes_FromStringAndSize(bytes_ptr.data, bytes_ptr.size))
- CHECK_CALL(TVMFFIObjectFree(temp_any.v_obj))
+ CHECK_CALL(TVMFFIObjectDecRef(temp_any.v_obj))
return res
diff --git a/ffi/python/tvm_ffi/cython/object.pxi
b/ffi/python/tvm_ffi/cython/object.pxi
index dad6bee51b..1203f0c682 100644
--- a/ffi/python/tvm_ffi/cython/object.pxi
+++ b/ffi/python/tvm_ffi/cython/object.pxi
@@ -78,7 +78,7 @@ cdef class Object:
def __dealloc__(self):
if self.chandle != NULL:
- CHECK_CALL(TVMFFIObjectFree(self.chandle))
+ CHECK_CALL(TVMFFIObjectDecRef(self.chandle))
self.chandle = NULL
def __ctypes_handle__(self):
diff --git a/ffi/src/ffi/object.cc b/ffi/src/ffi/object.cc
index 61107cb63f..f96636fd49 100644
--- a/ffi/src/ffi/object.cc
+++ b/ffi/src/ffi/object.cc
@@ -388,12 +388,18 @@ class TypeTable {
} // namespace ffi
} // namespace tvm
-int TVMFFIObjectFree(TVMFFIObjectHandle handle) {
+int TVMFFIObjectDecRef(TVMFFIObjectHandle handle) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::details::ObjectUnsafe::DecRefObjectHandle(handle);
TVM_FFI_SAFE_CALL_END();
}
+int TVMFFIObjectIncRef(TVMFFIObjectHandle handle) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ tvm::ffi::details::ObjectUnsafe::IncRefObjectHandle(handle);
+ TVM_FFI_SAFE_CALL_END();
+}
+
int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t* out_tindex)
{
TVM_FFI_SAFE_CALL_BEGIN();
out_tindex[0] = tvm::ffi::TypeTable::Global()->TypeKeyToIndex(type_key);
diff --git a/ffi/tests/cpp/test_c_ffi_abi.cc b/ffi/tests/cpp/test_c_ffi_abi.cc
index 1efceef297..e6c6116edd 100644
--- a/ffi/tests/cpp/test_c_ffi_abi.cc
+++ b/ffi/tests/cpp/test_c_ffi_abi.cc
@@ -25,7 +25,7 @@ TEST(ABIHeaderAlignment, Default) {
TVMFFIObject value;
value.type_index = 10;
EXPECT_EQ(reinterpret_cast<TVMFFIAny*>(&value)->type_index, 10);
- static_assert(sizeof(TVMFFIObject) == 16, "TVMFFIObject must be 16 bytes");
+ static_assert(sizeof(TVMFFIObject) == 24);
}
} // namespace
diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc
index 4b53a70b42..f6bedcb6f3 100644
--- a/ffi/tests/cpp/test_object.cc
+++ b/ffi/tests/cpp/test_object.cc
@@ -103,4 +103,123 @@ TEST(Object, CAPIAccessor) {
int32_t type_index = TVMFFIObjectGetTypeIndex(obj);
EXPECT_EQ(type_index, TIntObj::RuntimeTypeIndex());
}
+
+TEST(Object, WeakObjectPtr) {
+ // Test basic construction from ObjectPtr
+ ObjectPtr<TIntObj> strong_ptr = make_object<TIntObj>(42);
+ WeakObjectPtr<TIntObj> weak_ptr(strong_ptr);
+
+ EXPECT_EQ(strong_ptr.use_count(), 1);
+ EXPECT_FALSE(weak_ptr.expired());
+ EXPECT_EQ(weak_ptr.use_count(), 1);
+
+ // Test lock() when object is still alive
+ ObjectPtr<TIntObj> locked_ptr = weak_ptr.lock();
+ EXPECT_TRUE(locked_ptr != nullptr);
+ EXPECT_EQ(locked_ptr->value, 42);
+ EXPECT_EQ(strong_ptr.use_count(), 2);
+ EXPECT_EQ(weak_ptr.use_count(), 2);
+
+ // Test lock() when object is expired
+ strong_ptr.reset();
+ locked_ptr.reset();
+ EXPECT_TRUE(weak_ptr.expired());
+ EXPECT_EQ(weak_ptr.use_count(), 0);
+
+ ObjectPtr<TIntObj> expired_lock = weak_ptr.lock();
+ EXPECT_TRUE(expired_lock == nullptr);
+}
+
+TEST(Object, WeakObjectPtrAssignment) {
+ // Test copy construction
+ ObjectPtr<TIntObj> new_strong = make_object<TIntObj>(100);
+ WeakObjectPtr<TIntObj> weak1(new_strong);
+ WeakObjectPtr<TIntObj> weak2(weak1);
+
+ EXPECT_EQ(new_strong.use_count(), 1);
+ EXPECT_FALSE(weak1.expired());
+ EXPECT_FALSE(weak2.expired());
+ EXPECT_EQ(weak1.use_count(), 1);
+ EXPECT_EQ(weak2.use_count(), 1);
+
+ // Test move construction
+ WeakObjectPtr<TIntObj> weak3(std::move(weak1));
+ EXPECT_TRUE(weak1.expired()); // weak1 should be moved from
+ EXPECT_FALSE(weak3.expired());
+ EXPECT_EQ(weak3.use_count(), 1);
+
+ // Test assignment
+ WeakObjectPtr<TIntObj> weak4;
+ weak4 = weak2;
+ EXPECT_FALSE(weak2.expired());
+ EXPECT_FALSE(weak4.expired());
+ EXPECT_EQ(weak2.use_count(), 1);
+ EXPECT_EQ(weak4.use_count(), 1);
+
+ // Test move assignment
+ WeakObjectPtr<TIntObj> weak5;
+ weak5 = std::move(weak2);
+ EXPECT_TRUE(weak2.expired()); // weak2 should be moved from
+ EXPECT_FALSE(weak5.expired());
+ EXPECT_EQ(weak5.use_count(), 1);
+
+ // Test reset()
+ weak3.reset();
+ EXPECT_TRUE(weak3.expired());
+ EXPECT_EQ(weak3.use_count(), 0);
+
+ // Test swap()
+ ObjectPtr<TIntObj> strong_a = make_object<TIntObj>(200);
+ ObjectPtr<TIntObj> strong_b = make_object<TIntObj>(300);
+ WeakObjectPtr<TIntObj> weak_a(strong_a);
+ WeakObjectPtr<TIntObj> weak_b(strong_b);
+
+ weak_a.swap(weak_b);
+ EXPECT_EQ(weak_a.lock()->value, 300);
+ EXPECT_EQ(weak_b.lock()->value, 200);
+
+ // Test construction from nullptr
+ WeakObjectPtr<TIntObj> null_weak(nullptr);
+ EXPECT_TRUE(null_weak.expired());
+ EXPECT_EQ(null_weak.use_count(), 0);
+ EXPECT_TRUE(null_weak.lock() == nullptr);
+
+ // Test inheritance compatibility
+ ObjectPtr<TNumberObj> number_ptr = make_object<TIntObj>(500);
+ WeakObjectPtr<TNumberObj> number_weak(number_ptr);
+
+ EXPECT_FALSE(number_weak.expired());
+ EXPECT_EQ(number_weak.use_count(), 1);
+
+ // Test that weak references don't prevent object deletion
+ ObjectPtr<TIntObj> temp_strong = make_object<TIntObj>(999);
+ WeakObjectPtr<TIntObj> temp_weak(temp_strong);
+
+ EXPECT_FALSE(temp_weak.expired());
+ temp_strong.reset();
+ EXPECT_TRUE(temp_weak.expired());
+ EXPECT_TRUE(temp_weak.lock() == nullptr);
+
+ // Test multiple weak references
+ ObjectPtr<TIntObj> multi_strong = make_object<TIntObj>(777);
+ WeakObjectPtr<TIntObj> multi_weak1(multi_strong);
+ WeakObjectPtr<TIntObj> multi_weak2(multi_strong);
+ WeakObjectPtr<TIntObj> multi_weak3(multi_strong);
+
+ EXPECT_EQ(multi_strong.use_count(), 1);
+ EXPECT_FALSE(multi_weak1.expired());
+ EXPECT_FALSE(multi_weak2.expired());
+ EXPECT_FALSE(multi_weak3.expired());
+
+ // All weak references should be able to lock
+ ObjectPtr<TIntObj> lock1 = multi_weak1.lock();
+ ObjectPtr<TIntObj> lock2 = multi_weak2.lock();
+ ObjectPtr<TIntObj> lock3 = multi_weak3.lock();
+
+ EXPECT_EQ(multi_strong.use_count(), 4);
+ EXPECT_EQ(lock1->value, 777);
+ EXPECT_EQ(lock2->value, 777);
+ EXPECT_EQ(lock3->value, 777);
+}
+
} // namespace
diff --git a/jvm/native/src/main/native/jni_helper_func.h
b/jvm/native/src/main/native/jni_helper_func.h
index 5db3e279cf..9b50fb6a49 100644
--- a/jvm/native/src/main/native/jni_helper_func.h
+++ b/jvm/native/src/main/native/jni_helper_func.h
@@ -236,7 +236,7 @@ jobject tvmRetValueToJava(JNIEnv* env, TVMFFIAny value) {
}
case TypeIndex::kTVMFFIBytes: {
jobject ret = newTVMValueBytes(env,
TVMFFIBytesGetByteArrayPtr(value.v_obj));
- TVMFFIObjectFree(value.v_obj);
+ TVMFFIObjectDecRef(value.v_obj);
return ret;
}
default: {
diff --git a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
index 3ebe7fddfa..b512ec8775 100644
--- a/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
+++ b/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc
@@ -322,7 +322,7 @@ JNIEXPORT jint JNICALL
Java_org_apache_tvm_LibInfo_tvmFFIFunctionSetGlobal(JNIEn
// Module
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmFFIObjectFree(JNIEnv*
env, jobject obj,
jlong
jhandle) {
- return TVMFFIObjectFree(reinterpret_cast<TVMFFIObjectHandle>(jhandle));
+ return TVMFFIObjectDecRef(reinterpret_cast<TVMFFIObjectHandle>(jhandle));
}
// NDArray
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index 7477fe8636..e6c6e9aa02 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -299,10 +299,12 @@ PrimFunc MakePackedAPI(PrimFunc func) {
tvm::tir::StringImm(msg.str()), nop));
// if type_index is NDArray, we need to add the offset of the DLTensor
header
// which always equals 16 bytes, this ensures that T.handle always shows
up as a DLTensor*
+ const int64_t object_cell_offset = sizeof(TVMFFIObject);
+ static_assert(object_cell_offset == 24);
arg_value = f_load_arg_value(param.dtype(), i);
PrimExpr handle_from_ndarray =
Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
- {arg_value, IntImm(DataType::Int(32), 16)});
+ {arg_value, IntImm(DataType::Int(32), object_cell_offset)});
arg_value =
Select(type_index == ffi::TypeIndex::kTVMFFINDArray,
handle_from_ndarray, arg_value);
} else if (dtype.is_bool()) {
diff --git a/web/src/ctypes.ts b/web/src/ctypes.ts
index d2ecf4b944..9836fbfda5 100644
--- a/web/src/ctypes.ts
+++ b/web/src/ctypes.ts
@@ -41,7 +41,7 @@ export const enum SizeOf {
TVMFFIAny = 8 * 2,
DLDataType = I32,
DLDevice = I32 + I32,
- ObjectHeader = 8 * 2,
+ ObjectHeader = 8 * 3,
}
//---------------The new TVM FFI---------------
@@ -142,9 +142,9 @@ export type FTVMFFIWasmFunctionCreate = (
export type FTVMFFIWasmFunctionDeleter = (self: Pointer) => void;
/**
- * int TVMFFIObjectFree(TVMFFIObjectHandle obj);
+ * int TVMFFIObjectDecRef(TVMFFIObjectHandle obj);
*/
-export type FTVMFFIObjectFree = (obj: Pointer) => number;
+export type FTVMFFIObjectDecRef = (obj: Pointer) => number;
/**
* int TVMFFITypeKeyToIndex(const TVMFFIByteArray* type_key, int32_t*
out_tindex);
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index 071b2eed68..3720b1873e 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -450,7 +450,7 @@ export class TVMObject implements Disposable {
dispose(): void {
if (this.handle != 0) {
this.lib.checkCall(
- (this.lib.exports.TVMFFIObjectFree as
ctypes.FTVMFFIObjectFree)(this.handle)
+ (this.lib.exports.TVMFFIObjectDecRef as
ctypes.FTVMFFIObjectDecRef)(this.handle)
);
this.handle = 0;
}
@@ -2253,7 +2253,7 @@ export class Instance implements Disposable {
const strObjPtr = this.memory.loadPointer(valuePtr);
const result = this.memory.loadByteArrayAsString(strObjPtr +
SizeOf.ObjectHeader);
this.lib.checkCall(
- (this.lib.exports.TVMFFIObjectFree as
ctypes.FTVMFFIObjectFree)(strObjPtr)
+ (this.lib.exports.TVMFFIObjectDecRef as
ctypes.FTVMFFIObjectDecRef)(strObjPtr)
);
return result;
}
@@ -2264,7 +2264,7 @@ export class Instance implements Disposable {
const strObjPtr = this.memory.loadPointer(valuePtr);
const result = this.memory.loadByteArrayAsString(strObjPtr +
SizeOf.ObjectHeader);
this.lib.checkCall(
- (this.lib.exports.TVMFFIObjectFree as
ctypes.FTVMFFIObjectFree)(strObjPtr)
+ (this.lib.exports.TVMFFIObjectDecRef as
ctypes.FTVMFFIObjectDecRef)(strObjPtr)
);
return result;
}
@@ -2275,7 +2275,7 @@ export class Instance implements Disposable {
const bytesObjPtr = this.memory.loadPointer(valuePtr);
const result = this.memory.loadByteArrayAsBytes(bytesObjPtr +
SizeOf.ObjectHeader);
this.lib.checkCall(
- (this.lib.exports.TVMFFIObjectFree as
ctypes.FTVMFFIObjectFree)(bytesObjPtr)
+ (this.lib.exports.TVMFFIObjectDecRef as
ctypes.FTVMFFIObjectDecRef)(bytesObjPtr)
);
return result;
}