This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 43d13e8 [CORE] Update logic to use combined ref count (#58)
43d13e8 is described below
commit 43d13e86ee24d1558f929e3b0faa3182ca1af872
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Sep 26 13:05:50 2025 -0400
[CORE] Update logic to use combined ref count (#58)
This PR follows up on the previous PR and updates the logic to use the
combined ref count optimization.
---
docs/concepts/abi_overview.md  |  14 +++--
include/tvm/ffi/c_api.h        |  27 +++++---
include/tvm/ffi/memory.h       |   6 +-
include/tvm/ffi/object.h       | 136 ++++++++++++++++++++++++++---------------
pyproject.toml                 |   2 +-
python/tvm_ffi/cython/base.pxi |   5 +-
6 files changed, 119 insertions(+), 71 deletions(-)
diff --git a/docs/concepts/abi_overview.md b/docs/concepts/abi_overview.md
index b93397a..47639b5 100644
--- a/docs/concepts/abi_overview.md
+++ b/docs/concepts/abi_overview.md
@@ -191,8 +191,7 @@ we adopt a unified object storage format, defined as follows:
```c++
typedef struct TVMFFIObject {
- uint32_t strong_ref_count;
- uint32_t weak_ref_count;
+ uint64_t combined_ref_count;
int32_t type_index;
uint32_t __padding;
union {
@@ -204,13 +203,16 @@ typedef struct TVMFFIObject {
`TVMFFIObject` defines a common 24-byte intrusive header that all in-memory objects share:
-- `strong_ref_count` stores the strong atomic reference counter of the object.
-- `weak_ref_count` stores the weak atomic reference counter of the object.
+- `combined_ref_count` packs the strong and weak reference counters of the object into a single 64-bit field
+  - The lower 32 bits store the strong atomic reference counter: `strong_ref_count = combined_ref_count & 0xFFFFFFFF`
+  - The higher 32 bits store the weak atomic reference counter: `weak_ref_count = (combined_ref_count >> 32) & 0xFFFFFFFF`
- `type_index` helps us identify the type being stored, which is consistent with `TVMFFIAny.type_index`.
- `deleter` should be called when either the strong or weak ref counter goes to zero.
- The flags are set to indicate the event of either weak or strong going to zero, or both.
- - When `strong_ref_count` gets to zero, the deleter needs to call the destructor of the object.
- - When `weak_ref_count` gets to zero, the deleter needs to free the memory allocated by self.
+ - When the strong reference counter gets to zero, the deleter needs to call the destructor of the object.
+ - When the weak reference counter gets to zero, the deleter needs to free the memory allocated by self.
**Rationales:** There are several considerations when designing the data structure:
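As a quick, standalone illustration of the packing documented in the hunk above: the masks come from the doc change itself, while the `StrongCount`/`WeakCount` helper names are hypothetical and exist only for this sketch.
```c++
// Minimal sketch of decoding the combined 64-bit reference counter.
#include <cassert>
#include <cstdint>

inline uint32_t StrongCount(uint64_t combined_ref_count) {
  return static_cast<uint32_t>(combined_ref_count & 0xFFFFFFFF);
}

inline uint32_t WeakCount(uint64_t combined_ref_count) {
  return static_cast<uint32_t>((combined_ref_count >> 32) & 0xFFFFFFFF);
}

int main() {
  // A freshly constructed object holds one strong and one weak count.
  uint64_t combined = (static_cast<uint64_t>(1) << 32) | 1;
  assert(StrongCount(combined) == 1);
  assert(WeakCount(combined) == 1);
  // Adding 1 bumps only the strong count; adding 1 << 32 bumps only the weak count.
  combined += 1;
  combined += static_cast<uint64_t>(1) << 32;
  assert(StrongCount(combined) == 2);
  assert(WeakCount(combined) == 2);
  return 0;
}
```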
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index d2b9fab..0cac1f7 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -219,17 +219,26 @@ typedef enum {
* \brief C-based type of all FFI object header that allocates on heap.
*/
typedef struct {
- // Ref counter goes first to align ABI with most intrusive ptr designs.
- // It is also likely more efficient as rc operations can be quite common
- // ABI note: Strong ref counter and weak ref counter can be packed into a single 64-bit field
- // Hopefully in future being able to use 64bit atomic that avoids extra reading of weak counter during deletion.
- /*! \brief Strong reference counter of the object. */
- uint32_t strong_ref_count;
/*!
- * \brief Weak reference counter of the object, for compatiblity with weak_ptr design.
+ * \brief Combined strong and weak reference counter of the object.
+ *
+ * Strong ref counter is packed into the lower 32 bits.
+ * Weak ref counter is packed into the upper 32 bits.
+ *
+ * It is equivalent to { uint32_t strong_ref_count, uint32_t weak_ref_count }
+ * in little-endian structure:
+ *
+ * - strong_ref_count: `combined_ref_count & 0xFFFFFFFF`
+ * - weak_ref_count: `(combined_ref_count >> 32) & 0xFFFFFFFF`
+ *
+ * Rationale: atomic ops on the strong ref counter remain the same as +1/-1;
+ * the combined ref counter allows us to use a single u64 atomic
+ * instead of a separate atomic read of the weak counter during deletion.
+ *
+ * The ref counter goes first to align ABI with most intrusive ptr designs.
+ * It is also likely more efficient as rc operations can be quite common.
*/
- uint32_t weak_ref_count;
+ uint64_t combined_ref_count;
/*!
* \brief type index of the object.
* \note The type index of Object and Any are shared in FFI.
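The little-endian equivalence claimed in the new comment can be checked with a small standalone program; `SplitCounts` here is a hypothetical stand-in for the old two-field layout, and the check only holds on little-endian targets, as the comment itself notes.
```c++
// Sketch: the combined 64-bit counter has the same layout as the old pair of
// 32-bit counters on little-endian targets.
#include <cassert>
#include <cstdint>
#include <cstring>

struct SplitCounts {
  uint32_t strong_ref_count;
  uint32_t weak_ref_count;
};

int main() {
  uint64_t combined_ref_count = (static_cast<uint64_t>(3) << 32) | 2;  // weak = 3, strong = 2
  SplitCounts split;
  static_assert(sizeof(split) == sizeof(combined_ref_count), "layouts must match");
  std::memcpy(&split, &combined_ref_count, sizeof(split));
  assert(split.strong_ref_count == 2);
  assert(split.weak_ref_count == 3);
  return 0;
}
```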
diff --git a/include/tvm/ffi/memory.h b/include/tvm/ffi/memory.h
index 1fa9d65..76c9003 100644
--- a/include/tvm/ffi/memory.h
+++ b/include/tvm/ffi/memory.h
@@ -66,8 +66,7 @@ class ObjAllocatorBase {
static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
- ffi_ptr->strong_ref_count = 1;
- ffi_ptr->weak_ref_count = 1;
+ ffi_ptr->combined_ref_count = kCombinedRefCountBothOne;
ffi_ptr->type_index = T::RuntimeTypeIndex();
ffi_ptr->deleter = Handler::Deleter();
return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr);
@@ -88,8 +87,7 @@ class ObjAllocatorBase {
ArrayType* ptr =
Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
- ffi_ptr->strong_ref_count = 1;
- ffi_ptr->weak_ref_count = 1;
+ ffi_ptr->combined_ref_count = kCombinedRefCountBothOne;
ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
ffi_ptr->deleter = Handler::Deleter();
return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr);
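For reference, a minimal sketch of what the single-store initialization above encodes, using the constant values this commit adds to object.h:
```c++
// One 64-bit store now replaces the old pair of writes
// strong_ref_count = 1; weak_ref_count = 1;
#include <cassert>
#include <cstdint>

constexpr uint64_t kCombinedRefCountStrongOne = 1;
constexpr uint64_t kCombinedRefCountWeakOne = static_cast<uint64_t>(1) << 32;
constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne | kCombinedRefCountStrongOne;

int main() {
  uint64_t combined_ref_count = kCombinedRefCountBothOne;
  assert((combined_ref_count & 0xFFFFFFFF) == 1);          // strong count
  assert(((combined_ref_count >> 32) & 0xFFFFFFFF) == 1);  // weak count
  return 0;
}
```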
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index e5f955c..6eac9a4 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -129,6 +129,15 @@ namespace details {
// unsafe operations related to object
struct ObjectUnsafe;
+/*! \brief One counter for weak reference. */
+constexpr uint64_t kCombinedRefCountWeakOne = static_cast<uint64_t>(1) << 32;
+/*! \brief One counter for strong reference. */
+constexpr uint64_t kCombinedRefCountStrongOne = 1;
+/*! \brief Both reference counts. */
+constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne | kCombinedRefCountStrongOne;
+/*! \brief Mask to get the lower 32 bits of the combined reference count. */
+constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast<uint64_t>(1) << 32) - 1;
+
/*!
* Check if the type_index is an instance of TargetObjectType.
*
@@ -192,8 +201,7 @@ class Object {
public:
Object() {
- header_.strong_ref_count = 0;
- header_.weak_ref_count = 0;
+ header_.combined_ref_count = 0;
header_.deleter = nullptr;
}
/*!
@@ -247,12 +255,16 @@ class Object {
* \return The usage count of the cell.
* \note We use STL style naming to be consistent with known API in shared_ptr.
*/
- int32_t use_count() const {
+ uint64_t use_count() const {
// only need relaxed load of counters
#ifdef _MSC_VER
- return (reinterpret_cast<const volatile long*>(&header_.strong_ref_count))[0]; // NOLINT(*)
+ return ((reinterpret_cast<const volatile uint64_t*>(
+ &header_.combined_ref_count))[0] // NOLINT(*)
+ ) & kCombinedRefCountMaskUInt32;
#else
- return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED);
+ return __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED) & kCombinedRefCountMaskUInt32;
#endif
}
@@ -290,13 +302,18 @@ class Object {
static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
private:
+ // expose the combined ref count constants from the details namespace here
+ static constexpr uint64_t kCombinedRefCountMaskUInt32 = details::kCombinedRefCountMaskUInt32;
+ static constexpr uint64_t kCombinedRefCountStrongOne = details::kCombinedRefCountStrongOne;
+ static constexpr uint64_t kCombinedRefCountWeakOne = details::kCombinedRefCountWeakOne;
+ static constexpr uint64_t kCombinedRefCountBothOne = details::kCombinedRefCountBothOne;
/*! \brief increase strong reference count, the caller must already hold a strong reference */
void IncRef() {
#ifdef _MSC_VER
- _InterlockedIncrement(
- reinterpret_cast<volatile long*>(&header_.strong_ref_count)); // NOLINT(*)
+ _InterlockedIncrement64(
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count)); // NOLINT(*)
#else
- __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED);
+ __atomic_fetch_add(&(header_.combined_ref_count), 1, __ATOMIC_RELAXED);
#endif
}
/*!
@@ -306,12 +323,12 @@ class Object {
*/
bool TryPromoteWeakPtr() {
#ifdef _MSC_VER
- uint32_t old_count =
- (reinterpret_cast<const volatile long*>(&header_.strong_ref_count))[0]; // NOLINT(*)
- while (old_count > 0) {
- uint32_t new_count = old_count + 1;
- uint32_t old_count_loaded = _InterlockedCompareExchange(
- reinterpret_cast<volatile long*>(&header_.strong_ref_count), new_count, old_count);
+ uint64_t old_count =
+ (reinterpret_cast<const volatile __int64*>(&header_.combined_ref_count))[0]; // NOLINT(*)
+ while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
+ uint64_t new_count = old_count + kCombinedRefCountStrongOne;
+ uint64_t old_count_loaded = _InterlockedCompareExchange64(
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count), new_count, old_count);
if (old_count == old_count_loaded) {
return true;
}
@@ -319,13 +336,13 @@ class Object {
}
return false;
#else
- uint32_t old_count = __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED);
- while (old_count > 0) {
+ uint64_t old_count = __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED);
+ while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
// must do CAS to ensure that we are the only one that increases the reference count
// avoid condition when two threads tries to promote weak to strong at same time
// or when strong deletion happens between the load and the CAS
- uint32_t new_count = old_count + 1;
- if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count, new_count, true,
+ uint64_t new_count = old_count + kCombinedRefCountStrongOne;
+ if (__atomic_compare_exchange_n(&(header_.combined_ref_count), &old_count, new_count, true,
__ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) {
return true;
}
@@ -337,9 +354,11 @@ class Object {
/*! \brief increase weak reference count */
void IncWeakRef() {
#ifdef _MSC_VER
- _InterlockedIncrement(reinterpret_cast<volatile long*>(&header_.weak_ref_count)); // NOLINT(*)
+ _InlineInterlockedAdd64(
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count), // NOLINT(*)
+ kCombinedRefCountWeakOne);
#else
- __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED);
+ __atomic_fetch_add(&(header_.combined_ref_count), kCombinedRefCountWeakOne, __ATOMIC_RELAXED);
#endif
}
@@ -347,43 +366,62 @@ class Object {
void DecRef() {
#ifdef _MSC_VER
// use simpler impl in windows to ensure correctness
- if (_InterlockedDecrement( //
- reinterpret_cast<volatile long*>(&header_.strong_ref_count)) == 0) { // NOLINT(*)
- // full barrrier is implicit in InterlockedDecrement
+ uint64_t count_before_sub =
+ _InterlockedDecrement64( //
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count) // NOLINT(*)
+ ) +
+ 1;
+ if (count_before_sub == kCombinedRefCountBothOne) { // NOLINT(*)
+ // fast path: both reference counts will go to zero
+ if (header_.deleter != nullptr) {
+ // full barrier is implicit in InterlockedDecrement
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
+ }
+ } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) {
+ // strong reference count becomes zero, we need to first do strong deletion
+ // then decrease weak reference count
+ // full barrier is implicit in InterlockedAdd
if (header_.deleter != nullptr) {
header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong);
}
- if (_InterlockedDecrement( //
- reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0) { // NOLINT(*)
+ // decrease weak reference count
+ if (_InlineInterlockedAdd64( //
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),
+ -kCombinedRefCountWeakOne) == 0) { // NOLINT(*)
if (header_.deleter != nullptr) {
+ // full barrier is implicit in InterlockedAdd
header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
}
}
}
#else
// first do a release, note we only need to acquire for deleter
- if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE) == 1) {
- if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) {
- // common case, we need to delete both the object and the memory block
- // only acquire when we need to call deleter
- __atomic_thread_fence(__ATOMIC_ACQUIRE);
- if (header_.deleter != nullptr) {
- // call deleter once
- header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
- }
- } else {
- // Slower path: there is still a weak reference left
+ // optimization: we only need one atomic to tell the common case
+ // where both reference counts are zero
+ uint64_t count_before_sub = __atomic_fetch_sub(&(header_.combined_ref_count),
+ kCombinedRefCountStrongOne, __ATOMIC_RELEASE);
+ if (count_before_sub == kCombinedRefCountBothOne) {
+ // common case, we need to delete both the object and the memory block
+ // only acquire when we need to call deleter
+ __atomic_thread_fence(__ATOMIC_ACQUIRE);
+ if (header_.deleter != nullptr) {
+ // call deleter once
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
+ }
+ } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) {
+ // strong count is already zero
+ // Slower path: there is still a weak reference left
+ __atomic_thread_fence(__ATOMIC_ACQUIRE);
+ // call destructor first, then decrease weak reference count
+ if (header_.deleter != nullptr) {
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong);
+ }
+ // now decrease weak reference count
+ if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne,
+ __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
__atomic_thread_fence(__ATOMIC_ACQUIRE);
- // call destructor first, then decrease weak reference count
if (header_.deleter != nullptr) {
- header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong);
- }
- // now decrease weak reference count
- if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) {
- __atomic_thread_fence(__ATOMIC_ACQUIRE);
- if (header_.deleter != nullptr) {
- header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
- }
+ header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
}
}
}
@@ -393,15 +431,17 @@ class Object {
/*! \brief decrease weak reference count */
void DecWeakRef() {
#ifdef _MSC_VER
- if (_InterlockedDecrement( //
- reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0) { // NOLINT(*)
+ if (_InlineInterlockedAdd64( //
+ reinterpret_cast<volatile __int64*>(&header_.combined_ref_count), // NOLINT(*)
+ -kCombinedRefCountWeakOne) == 0) {
if (header_.deleter != nullptr) {
header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
}
}
#else
// now decrease weak reference count
- if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 1) {
+ if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne,
+ __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
__atomic_thread_fence(__ATOMIC_ACQUIRE);
if (header_.deleter != nullptr) {
header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
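To make the new `DecRef` branching easier to follow, here is a simplified, single-threaded model of the decision logic: plain arithmetic stands in for the atomics, and the `Delete*` functions are hypothetical stand-ins for calling `header_.deleter` with the corresponding flag bits.
```c++
// Non-atomic model of the DecRef fast path and slow path; for illustration only.
#include <cstdint>
#include <cstdio>

constexpr uint64_t kStrongOne = 1;
constexpr uint64_t kWeakOne = static_cast<uint64_t>(1) << 32;
constexpr uint64_t kBothOne = kWeakOne | kStrongOne;
constexpr uint64_t kMaskUInt32 = kWeakOne - 1;

// Hypothetical stand-ins for the deleter invoked with the Both/Strong/Weak flags.
void DeleteBoth() { std::printf("destruct object and free memory\n"); }
void DeleteStrong() { std::printf("destruct object only\n"); }
void DeleteWeak() { std::printf("free memory only\n"); }

void DecRefModel(uint64_t& combined_ref_count) {
  uint64_t count_before_sub = combined_ref_count;
  combined_ref_count -= kStrongOne;
  if (count_before_sub == kBothOne) {
    // Fast path: this strong release also drops the last weak count,
    // so a single deleter call handles both destruction and deallocation.
    DeleteBoth();
  } else if ((count_before_sub & kMaskUInt32) == kStrongOne) {
    // Strong count reached zero but weak references remain:
    // run the destructor, then drop the weak count owned by the strong side.
    DeleteStrong();
    uint64_t weak_before_sub = combined_ref_count;
    combined_ref_count -= kWeakOne;
    if (weak_before_sub == kWeakOne) DeleteWeak();
  }
}

int main() {
  uint64_t rc = kBothOne;  // freshly created object: strong = 1, weak = 1
  DecRefModel(rc);         // prints "destruct object and free memory"
  return 0;
}
```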
diff --git a/pyproject.toml b/pyproject.toml
index 83a280b..8166de4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
[project]
name = "apache-tvm-ffi"
-version = "0.1.0b7"
+version = "0.1.0b8"
description = "tvm ffi"
authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index ff532ea..a3ab73e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -111,10 +111,9 @@ cdef extern from "tvm/ffi/c_api.h":
ctypedef void* TVMFFIObjectHandle
ctypedef struct TVMFFIObject:
- uint32_t strong_ref_count
- uint32_t weak_ref_count
+ uint64_t combined_ref_count
int32_t type_index
- int32_t __padding
+ uint32_t __padding
void (*deleter)(TVMFFIObject* self)
ctypedef struct TVMFFIAny: