This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 39d9b2b4 [CORE] Enable customized AnyHash/Equal in Object Type attr
(#451)
39d9b2b4 is described below
commit 39d9b2b400646be720e98f001353cc0d8d4b0234
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Feb 15 21:33:01 2026 -0500
[CORE] Enable customized AnyHash/Equal in Object Type attr (#451)
This PR lets each object type customize AnyHash/AnyEqual by registering hooks in the
`__any_hash__` and `__any_equal__` type attribute columns.
---
include/tvm/ffi/any.h | 128 +++++++++++++++++++++++++++++++++++++
include/tvm/ffi/endian.h | 1 -
include/tvm/ffi/error.h | 19 ++++++
include/tvm/ffi/function_details.h | 20 ------
src/ffi/container.cc | 2 +
tests/cpp/test_any.cc | 30 +++++++++
tests/cpp/testing_object.h | 42 +++++++++---
7 files changed, 212 insertions(+), 30 deletions(-)
diff --git a/include/tvm/ffi/any.h b/include/tvm/ffi/any.h
index 9adbd849..530d9aaa 100644
--- a/include/tvm/ffi/any.h
+++ b/include/tvm/ffi/any.h
@@ -601,6 +601,7 @@ struct AnyUnsafe : public ObjectUnsafe {
-/*! \brief String-aware Any equal functor */
+/*! \brief String-aware Any hash functor */
struct AnyHash {
+ public:
/*!
* \brief Calculate the hash code of an Any
* \param a The given Any
@@ -623,13 +624,77 @@ struct AnyHash {
return details::StableHashCombine(src.data_.type_index,
details::StableHashBytes(src_str->data, src_str->size));
} else {
+ if (src.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ static const TVMFFITypeAttrColumn* custom_hash_column =
GetAnyHashTypeAttrColumn();
+ if (custom_hash_column != nullptr &&
+ static_cast<size_t>(src.data_.type_index) <
custom_hash_column->size) {
+ const TVMFFIAny& custom_any_hash =
custom_hash_column->data[src.data_.type_index];
+ if (custom_any_hash.type_index != TypeIndex::kTVMFFINone) {
+ return details::StableHashCombine(src.data_.type_index,
+
CallCustomAnyHash(custom_any_hash, src));
+ }
+ }
+ }
return details::StableHashCombine(src.data_.type_index,
src.data_.v_uint64);
}
}
+
+ private:
+  /*!
+   * \brief Look up the type attribute column holding per-type custom hash hooks.
+   * \return The `__any_hash__` column; may be nullptr (callers check for this).
+   */
+  static const TVMFFITypeAttrColumn* GetAnyHashTypeAttrColumn() {
+    constexpr char kAnyHashAttr[] = "__any_hash__";
+    TVMFFIByteArray column_name{kAnyHashAttr, sizeof(kAnyHashAttr) - 1};
+    return TVMFFIGetTypeAttrColumn(&column_name);
+  }
+
+  /*!
+   * \brief Invoke the custom hash hook stored in the `__any_hash__` type attribute.
+   * \param custom_any_hash The attribute value: either an opaque pointer to an
+   *        `int64_t (*)(const Any&)` function, or an ffi.Function object returning an int.
+   * \param src The Any being hashed.
+   * \return The hook's int64 result reinterpreted as uint64_t.
+   * \note An error raised by an ffi.Function hook through the safe-call ABI is rethrown.
+   */
+  static uint64_t CallCustomAnyHash(const TVMFFIAny& custom_any_hash, const Any& src) {
+    // Only the raw C ABI is used here so that any.h does not depend on function.h.
+    if (custom_any_hash.type_index != TypeIndex::kTVMFFIOpaquePtr) {
+      // The attribute must then be an ffi.Function object.
+      TVM_FFI_ICHECK_EQ(custom_any_hash.type_index, TypeIndex::kTVMFFIFunction);
+      TVMFFIAny call_arg = src.data_;
+      TVMFFIAny ret_value;
+      ret_value.type_index = TypeIndex::kTVMFFINone;
+      ret_value.zero_padding = 0;
+      ret_value.v_int64 = 0;
+      TVMFFIFunctionCell* cell = TVMFFIFunctionGetCellPtr(custom_any_hash.v_obj);
+      if (cell->cpp_call == nullptr) {
+        if (cell->safe_call(custom_any_hash.v_obj, &call_arg, 1, &ret_value) != 0) {
+          throw details::MoveFromSafeCallRaised();
+        }
+      } else {
+        // Prefer the direct C++ calling convention when the function exposes one.
+        using FCppCall = void (*)(const void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
+        reinterpret_cast<FCppCall>(cell->cpp_call)(custom_any_hash.v_obj, &call_arg, 1,
+                                                   &ret_value);
+      }
+      // Adopt the result so it is released on every exit path.
+      Any hashed = details::AnyUnsafe::MoveTVMFFIAnyToAny(&ret_value);
+      TVM_FFI_ICHECK_EQ(hashed.type_index(), TypeIndex::kTVMFFIInt);
+      return static_cast<uint64_t>(hashed.data_.v_int64);
+    }
+    // Fast path: the attribute is a raw function pointer.
+    using FCustomAnyHashPtr = int64_t (*)(const Any&);
+    FCustomAnyHashPtr hash_fn = reinterpret_cast<FCustomAnyHashPtr>(custom_any_hash.v_ptr);
+    TVM_FFI_ICHECK_NOTNULL(hash_fn);
+    return static_cast<uint64_t>(
+        (*hash_fn)(src));  // NOLINT(clang-analyzer-core.CallAndMessage)
+  }
};
-/*! \brief String-aware Any hash functor */
+/*! \brief String-aware Any equal functor */
struct AnyEqual {
+ public:
/*!
* \brief Check if the two Any are equal
* \param lhs left operand.
@@ -657,6 +722,16 @@ struct AnyEqual {
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const
details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size,
rhs_str->size);
}
+ if (lhs.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+ static const TVMFFITypeAttrColumn* custom_equal_column =
GetAnyEqualTypeAttrColumn();
+ if (custom_equal_column != nullptr &&
+ static_cast<size_t>(lhs.data_.type_index) <
custom_equal_column->size) {
+ const TVMFFIAny& custom_any_equal =
custom_equal_column->data[lhs.data_.type_index];
+ if (custom_any_equal.type_index != TypeIndex::kTVMFFINone) {
+ return CallCustomAnyEqual(custom_any_equal, lhs, rhs);
+ }
+ }
+ }
return false;
} else {
// type_index mismatch, if index is not string, return false
@@ -692,6 +767,59 @@ struct AnyEqual {
return false;
}
}
+
+ private:
+  /*!
+   * \brief Look up the type attribute column holding per-type custom equality hooks.
+   * \return The `__any_equal__` column; may be nullptr (callers check for this).
+   */
+  static const TVMFFITypeAttrColumn* GetAnyEqualTypeAttrColumn() {
+    constexpr char kAnyEqualAttr[] = "__any_equal__";
+    TVMFFIByteArray column_name{kAnyEqualAttr, sizeof(kAnyEqualAttr) - 1};
+    return TVMFFIGetTypeAttrColumn(&column_name);
+  }
+
+  /*!
+   * \brief Invoke the custom equality hook stored in the `__any_equal__` type attribute.
+   * \param custom_any_equal The attribute value: either an opaque pointer to a
+   *        `bool (*)(const Any&, const Any&)` function, or an ffi.Function object
+   *        returning a bool.
+   * \param lhs The left-hand side Any.
+   * \param rhs The right-hand side Any.
+   * \return The equality result reported by the hook.
+   * \note An error raised by an ffi.Function hook through the safe-call ABI is rethrown.
+   */
+  static bool CallCustomAnyEqual(const TVMFFIAny& custom_any_equal, const Any& lhs,
+                                 const Any& rhs) {
+    // NOTE: we explicitly use the low-level ABI here to avoid depending on function.h;
+    // it also keeps the logic simple.
+    if (custom_any_equal.type_index == TypeIndex::kTVMFFIOpaquePtr) {
+      // Fast path: the attribute is a raw function pointer.
+      using FCustomAnyEqualPtr = bool (*)(const Any&, const Any&);
+      FCustomAnyEqualPtr equal_func =
+          reinterpret_cast<FCustomAnyEqualPtr>(custom_any_equal.v_ptr);
+      TVM_FFI_ICHECK_NOTNULL(equal_func);
+      return (*equal_func)(lhs, rhs);  // NOLINT(clang-analyzer-core.CallAndMessage)
+    }
+    // Otherwise the attribute must be an ffi.Function object.
+    TVM_FFI_ICHECK_EQ(custom_any_equal.type_index, TypeIndex::kTVMFFIFunction);
+    // Raw payload copies of lhs/rhs; no new references are taken, and lhs/rhs outlive the call.
+    TVMFFIAny args[2] = {lhs.data_, rhs.data_};
+    TVMFFIAny result;
+    result.type_index = TypeIndex::kTVMFFINone;
+    result.zero_padding = 0;
+    result.v_int64 = 0;
+    TVMFFIFunctionCell* func_cell = TVMFFIFunctionGetCellPtr(custom_any_equal.v_obj);
+    if (func_cell->cpp_call != nullptr) {
+      // Fast path: invoke the C++ ABI call directly when available.
+      using FCppCall = void (*)(const void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
+      reinterpret_cast<FCppCall>(func_cell->cpp_call)(custom_any_equal.v_obj, args, 2, &result);
+    } else if (func_cell->safe_call(custom_any_equal.v_obj, args, 2, &result) != 0) {
+      throw details::MoveFromSafeCallRaised();
+    }
+    // Adopt the result so it is released on every exit path (including a failed check).
+    Any result_any = details::AnyUnsafe::MoveTVMFFIAnyToAny(&result);
+    // Same binary check form as CallCustomAnyHash, so a mismatch reports both operands.
+    TVM_FFI_ICHECK_EQ(result_any.type_index(), TypeIndex::kTVMFFIBool);
+    return result_any.data_.v_int64 != 0;
+  }
};
} // namespace ffi
diff --git a/include/tvm/ffi/endian.h b/include/tvm/ffi/endian.h
index 10639bea..c068c3d5 100644
--- a/include/tvm/ffi/endian.h
+++ b/include/tvm/ffi/endian.h
@@ -1,4 +1,3 @@
-
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
diff --git a/include/tvm/ffi/error.h b/include/tvm/ffi/error.h
index f1cfa213..3ecfbb72 100644
--- a/include/tvm/ffi/error.h
+++ b/include/tvm/ffi/error.h
@@ -317,6 +317,25 @@ inline Error EnvErrorAlreadySet() { return
Error("EnvErrorAlreadySet", "", "");
namespace details {
+/*!
+ * \brief Take the error most recently raised through the safe-call ABI out of TLS.
+ * \return The raised error as an Error reference.
+ */
+TVM_FFI_INLINE Error MoveFromSafeCallRaised() {
+  TVMFFIObjectHandle raised_handle;
+  TVMFFIErrorMoveFromRaised(&raised_handle);
+  // The handle is owned by the caller, so ObjectPtrFromOwned adopts that ownership.
+  TVMFFIObject* raised_obj = static_cast<TVMFFIObject*>(raised_handle);
+  return ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
+      ObjectUnsafe::ObjectPtrFromOwned<Object>(raised_obj));
+}
+
+/*!
+ * \brief Publish an error as the pending safe-call error in TLS.
+ * \param error The error to hand back to the caller of the safe-call ABI.
+ */
+TVM_FFI_INLINE void SetSafeCallRaised(const Error& error) {
+  auto error_handle = ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error);
+  TVMFFIErrorSetRaised(error_handle);
+}
+
class ErrorBuilder {
public:
explicit ErrorBuilder(std::string kind, std::string backtrace, bool
log_before_throw)
diff --git a/include/tvm/ffi/function_details.h
b/include/tvm/ffi/function_details.h
index d04c697f..cf33a79c 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -231,26 +231,6 @@ TVM_FFI_INLINE void
unpack_call(std::index_sequence<Is...>, const std::string* o
// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks
}
-/*!
- * \brief Move the safe call raised error to the caller
- * \return The error
- */
-TVM_FFI_INLINE static Error MoveFromSafeCallRaised() {
- TVMFFIObjectHandle handle;
- TVMFFIErrorMoveFromRaised(&handle);
- // handle is owned by caller
- return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
-
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)));
-}
-
-/*!
- * \brief Set the safe call raised error
- * \param error The error
- */
-TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) {
-
TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error));
-}
-
template <typename T>
struct TypeSchemaImpl {
static std::string v() {
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index 6ebc7c07..570fd755 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -63,6 +63,8 @@ ObjectRef GetMissingObject() {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
+ refl::EnsureTypeAttrColumn("__any_hash__");
+ refl::EnsureTypeAttrColumn("__any_equal__");
refl::GlobalDef()
.def_packed("ffi.Array",
[](ffi::PackedArgs args, Any* ret) {
diff --git a/tests/cpp/test_any.cc b/tests/cpp/test_any.cc
index c29ad5c8..2686b7a8 100644
--- a/tests/cpp/test_any.cc
+++ b/tests/cpp/test_any.cc
@@ -428,4 +428,34 @@ TEST(Any, AnyEqualHash) {
EXPECT_EQ(AnyHash()(c), AnyHash()(d));
}
+TEST(Any, CustomAnyHash) {
+  // AnyHash must mix the value's type index with the result of the custom hook.
+  auto check_custom_hash = [](const Any& value, uint64_t custom_hash) {
+    const uint64_t expected = details::StableHashCombine(value.type_index(), custom_hash);
+    EXPECT_EQ(AnyHash()(value), expected);
+  };
+
+  // TInt stores its hook as a raw function pointer (OpaquePtr branch).
+  Any int_value = TInt(7);
+  check_custom_hash(int_value, static_cast<uint64_t>(TInt::CustomAnyHash(int_value)));
+
+  // TFloat registers its hook via `.def`, so it is an ffi.Function (Function branch).
+  Any float_value = TFloat(3.5);
+  check_custom_hash(float_value, static_cast<uint64_t>(TFloat::CustomAnyHash(float_value)));
+}
+
+TEST(Any, CustomAnyEqual) {
+  AnyEqual any_equal;
+
+  // TInt: equality hook stored as a raw function pointer (OpaquePtr branch).
+  Any seven = TInt(7);
+  Any another_seven = TInt(7);
+  Any eight = TInt(8);
+  EXPECT_TRUE(any_equal(seven, another_seven));
+  EXPECT_FALSE(any_equal(seven, eight));
+
+  // TFloat: equality hook registered as an ffi.Function (Function branch).
+  Any three_and_half = TFloat(3.5);
+  Any another_three_and_half = TFloat(3.5);
+  Any four_and_half = TFloat(4.5);
+  EXPECT_TRUE(any_equal(three_and_half, another_three_and_half));
+  EXPECT_FALSE(any_equal(three_and_half, four_and_half));
+}
+
} // namespace
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index 318aa116..b3ce5504 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -20,6 +20,7 @@
#ifndef TVM_FFI_TESTING_OBJECT_H_
#define TVM_FFI_TESTING_OBJECT_H_
+#include <tvm/ffi/any.h>
#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/memory.h>
@@ -74,6 +75,14 @@ class TInt : public TNumber {
static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value +
rhs->value); }
+  static int64_t CustomAnyHash(const Any& src) {
+    // Test hook: the wrapped integer shifted by a fixed offset of 1024.
+    TInt number = src.cast<TInt>();
+    return static_cast<int64_t>(number->value + 1024);
+  }
+
+  static bool CustomAnyEqual(const Any& lhs, const Any& rhs) {
+    // Test hook: two TInt values are equal iff their wrapped integers match.
+    const auto lhs_value = lhs.cast<TInt>()->value;
+    const auto rhs_value = rhs.cast<TInt>()->value;
+    return lhs_value == rhs_value;
+  }
+
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TInt, TNumber, TIntObj);
};
@@ -85,7 +94,9 @@ inline void TIntObj::RegisterReflection() {
// define extra type attributes
refl::TypeAttrDef<TIntObj>()
.def("test.GetValue", &TIntObj::GetValue)
- .attr("test.size", sizeof(TIntObj));
+ .attr("test.size", sizeof(TIntObj))
+ .attr("__any_hash__", reinterpret_cast<void*>(&TInt::CustomAnyHash))
+ .attr("__any_equal__", reinterpret_cast<void*>(&TInt::CustomAnyEqual));
// custom json serialization
refl::TypeAttrDef<TIntObj>()
.def("__data_to_json__",
@@ -105,14 +116,7 @@ class TFloatObj : public TNumberObj {
double Add(double other) const { return value + other; }
- static void RegisterReflection() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<TFloatObj>()
- .def_ro("value", &TFloatObj::value, "float value field",
refl::DefaultValue(10.0))
- .def("sub",
- [](const TFloatObj* self, double other) -> double { return
self->value - other; })
- .def("add", &TFloatObj::Add, "add method");
- }
+ static void RegisterReflection();
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Float", TFloatObj, TNumberObj);
};
@@ -121,9 +125,29 @@ class TFloat : public TNumber {
public:
explicit TFloat(double value) { data_ = make_object<TFloatObj>(value); }
+  /*!
+   * \brief Test-only custom hash for TFloat, registered as an ffi.Function hook.
+   * \param src An Any holding a TFloat.
+   * \return A deterministic hash derived from the stored value.
+   * \note Returns int64_t to match the signed hash contract used by TInt::CustomAnyHash,
+   *       the `int64_t (*)(const Any&)` pointer path, and the Int result AnyHash expects.
+   */
+  static int64_t CustomAnyHash(const Any& src) {
+    double value = src.cast<TFloat>()->value;
+    return static_cast<int64_t>(value * 10 + 2048);
+  }
+
+  static bool CustomAnyEqual(const Any& lhs, const Any& rhs) {
+    // Test hook: exact comparison of the wrapped floating-point values.
+    const double lhs_value = lhs.cast<TFloat>()->value;
+    const double rhs_value = rhs.cast<TFloat>()->value;
+    return lhs_value == rhs_value;
+  }
+
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TFloat, TNumber, TFloatObj);
};
+// Registers reflection metadata for TFloatObj: the `value` field (def_ro, default 10.0),
+// the `sub`/`add` methods, and the `__any_hash__`/`__any_equal__` type attributes.
+// The hooks are registered with `.def`, so they are stored as ffi.Function objects and
+// exercise the Function branch of AnyHash/AnyEqual (TInt covers the raw-pointer branch).
+inline void TFloatObj::RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<TFloatObj>()
+ .def_ro("value", &TFloatObj::value, "float value field",
refl::DefaultValue(10.0))
+ .def("sub", [](const TFloatObj* self, double other) -> double { return
self->value - other; })
+ .def("add", &TFloatObj::Add, "add method");
+ // Custom hash/equality hooks looked up by AnyHash/AnyEqual via the type attr columns.
+ refl::TypeAttrDef<TFloatObj>()
+ .def("__any_hash__", &TFloat::CustomAnyHash)
+ .def("__any_equal__", &TFloat::CustomAnyEqual);
+}
+
class TPrimExprObj : public Object {
public:
std::string dtype;