This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 39d9b2b4 [CORE] Enable customized AnyHash/Equal in Object Type attr (#451)
39d9b2b4 is described below

commit 39d9b2b400646be720e98f001353cc0d8d4b0234
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Feb 15 21:33:01 2026 -0500

    [CORE] Enable customized AnyHash/Equal in Object Type attr (#451)
    
    This PR enables per-object-type customization of AnyHash/AnyEqual through the
    `__any_hash__` and `__any_equal__` type attribute columns.
---
 include/tvm/ffi/any.h              | 128 +++++++++++++++++++++++++++++++++++++
 include/tvm/ffi/endian.h           |   1 -
 include/tvm/ffi/error.h            |  19 ++++++
 include/tvm/ffi/function_details.h |  20 ------
 src/ffi/container.cc               |   2 +
 tests/cpp/test_any.cc              |  30 +++++++++
 tests/cpp/testing_object.h         |  42 +++++++++---
 7 files changed, 212 insertions(+), 30 deletions(-)

diff --git a/include/tvm/ffi/any.h b/include/tvm/ffi/any.h
index 9adbd849..530d9aaa 100644
--- a/include/tvm/ffi/any.h
+++ b/include/tvm/ffi/any.h
@@ -601,6 +601,7 @@ struct AnyUnsafe : public ObjectUnsafe {
 
 /*! \brief String-aware Any equal functor */
 struct AnyHash {
+ public:
   /*!
    * \brief Calculate the hash code of an Any
    * \param a The given Any
@@ -623,13 +624,77 @@ struct AnyHash {
       return details::StableHashCombine(src.data_.type_index,
                                         
details::StableHashBytes(src_str->data, src_str->size));
     } else {
+      if (src.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+        static const TVMFFITypeAttrColumn* custom_hash_column = 
GetAnyHashTypeAttrColumn();
+        if (custom_hash_column != nullptr &&
+            static_cast<size_t>(src.data_.type_index) < 
custom_hash_column->size) {
+          const TVMFFIAny& custom_any_hash = 
custom_hash_column->data[src.data_.type_index];
+          if (custom_any_hash.type_index != TypeIndex::kTVMFFINone) {
+            return details::StableHashCombine(src.data_.type_index,
+                                              
CallCustomAnyHash(custom_any_hash, src));
+          }
+        }
+      }
       return details::StableHashCombine(src.data_.type_index, 
src.data_.v_uint64);
     }
   }
+
+ private:
+  /*!
+   * \brief Look up the "__any_hash__" type attribute column.
+   * \return The column, or nullptr when the lookup yields none (the caller checks for null).
+   * \note operator() caches this result in a function-local static, so the column has to
+   *       exist before the first custom hash is attempted.
+   */
+  static const TVMFFITypeAttrColumn* GetAnyHashTypeAttrColumn() {
+    static constexpr char kAnyHashAttrName[] = "__any_hash__";
+    // sizeof counts the trailing '\0', which is not part of the attribute name.
+    TVMFFIByteArray name{kAnyHashAttrName, sizeof(kAnyHashAttrName) - 1};
+    return TVMFFIGetTypeAttrColumn(&name);
+  }
+
+  /*!
+   * \brief Call the custom any hash function registered in type attribute column.
+   * \param custom_any_hash The custom any hash function object or function pointer.
+   *        As an opaque pointer it must point to a function with exactly the signature
+   *        `int64_t(const Any&)` (calling through a mismatched signature is undefined
+   *        behavior); as an ffi.Function it must return an int.
+   * \param src The source Any.
+   * \return The hash value (the caller combines it with the type index).
+   * \throws Error re-raised from the safe-call path when the ffi.Function fails.
+   */
+  static uint64_t CallCustomAnyHash(const TVMFFIAny& custom_any_hash, const Any& src) {
+    // NOTE: we explicitly use low-level ABI here since we do not want to have dep on function.h
+    // it also keeps the logic simple
+    if (custom_any_hash.type_index == TypeIndex::kTVMFFIOpaquePtr) {
+      // we allow this attribute to be a function pointer for fast path
+      using FCustomAnyHashPtr = int64_t (*)(const Any&);
+      FCustomAnyHashPtr hash_func = reinterpret_cast<FCustomAnyHashPtr>(custom_any_hash.v_ptr);
+      TVM_FFI_ICHECK_NOTNULL(hash_func);
+      return static_cast<uint64_t>(
+          (*hash_func)(src));  // NOLINT(clang-analyzer-core.CallAndMessage)
+    } else {
+      // alternatively it can be a ffi.Function object.
+      TVM_FFI_ICHECK_EQ(custom_any_hash.type_index, TypeIndex::kTVMFFIFunction);
+      // Plain struct copy of src's payload (no refcount change); only valid during the call.
+      TVMFFIAny arg = src.data_;
+      // Pre-initialize the result slot to None so it holds a valid value before the callee
+      // writes to it.
+      TVMFFIAny result;
+      result.type_index = TypeIndex::kTVMFFINone;
+      result.zero_padding = 0;
+      result.v_int64 = 0;
+      TVMFFIFunctionCell* func_cell = TVMFFIFunctionGetCellPtr(custom_any_hash.v_obj);
+      if (func_cell->cpp_call != nullptr) {
+        // Fast path: invoke C++ ABI call directly when available.
+        using FCppCall = void (*)(const void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
+        reinterpret_cast<FCppCall>(func_cell->cpp_call)(custom_any_hash.v_obj, &arg, 1, &result);
+      } else {
+        // C ABI path: a nonzero return code signals an error stored for the caller to fetch.
+        if (func_cell->safe_call(custom_any_hash.v_obj, &arg, 1, &result) != 0) {
+          throw details::MoveFromSafeCallRaised();
+        }
+      }
+      // Take ownership of whatever the callee wrote so it is released on every exit path.
+      Any result_any = details::AnyUnsafe::MoveTVMFFIAnyToAny(&result);
+      TVM_FFI_ICHECK_EQ(result_any.type_index(), TypeIndex::kTVMFFIInt);
+      return static_cast<uint64_t>(result_any.data_.v_int64);
+    }
+  }
 };
 
 /*! \brief String-aware Any hash functor */
 struct AnyEqual {
+ public:
   /*!
    * \brief Check if the two Any are equal
    * \param lhs left operand.
@@ -657,6 +722,16 @@ struct AnyEqual {
             details::AnyUnsafe::CopyFromAnyViewAfterCheck<const 
details::BytesObjBase*>(rhs);
         return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, 
rhs_str->size);
       }
+      if (lhs.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) {
+        static const TVMFFITypeAttrColumn* custom_equal_column = 
GetAnyEqualTypeAttrColumn();
+        if (custom_equal_column != nullptr &&
+            static_cast<size_t>(lhs.data_.type_index) < 
custom_equal_column->size) {
+          const TVMFFIAny& custom_any_equal = 
custom_equal_column->data[lhs.data_.type_index];
+          if (custom_any_equal.type_index != TypeIndex::kTVMFFINone) {
+            return CallCustomAnyEqual(custom_any_equal, lhs, rhs);
+          }
+        }
+      }
       return false;
     } else {
       // type_index mismatch, if index is not string, return false
@@ -692,6 +767,59 @@ struct AnyEqual {
       return false;
     }
   }
+
+ private:
+  /*!
+   * \brief Look up the "__any_equal__" type attribute column.
+   * \return The column, or nullptr when the lookup yields none (the caller checks for null).
+   * \note operator() caches this result in a function-local static, so the column has to
+   *       exist before the first custom comparison is attempted.
+   */
+  static const TVMFFITypeAttrColumn* GetAnyEqualTypeAttrColumn() {
+    static constexpr char kAnyEqualAttrName[] = "__any_equal__";
+    // sizeof counts the trailing '\0', which is not part of the attribute name.
+    TVMFFIByteArray name{kAnyEqualAttrName, sizeof(kAnyEqualAttrName) - 1};
+    return TVMFFIGetTypeAttrColumn(&name);
+  }
+
+  /*!
+   * \brief Call the custom any equal function registered in type attribute column.
+   * \param custom_any_equal The custom any equal function object or function pointer.
+   *        As an opaque pointer it must point to a function with exactly the signature
+   *        `bool(const Any&, const Any&)`; as an ffi.Function it must return a bool.
+   * \param lhs The left-hand side Any.
+   * \param rhs The right-hand side Any.
+   * \return The equality result.
+   * \throws Error re-raised from the safe-call path when the ffi.Function fails.
+   */
+  static bool CallCustomAnyEqual(const TVMFFIAny& custom_any_equal, const Any& lhs,
+                                 const Any& rhs) {
+    // NOTE: we explicitly use low-level ABI here since we do not want to have dep on function.h
+    // it also keeps the logic simple
+    if (custom_any_equal.type_index == TypeIndex::kTVMFFIOpaquePtr) {
+      // we allow this attribute to be a function pointer for fast path
+      using FCustomAnyEqualPtr = bool (*)(const Any&, const Any&);
+      FCustomAnyEqualPtr equal_func = reinterpret_cast<FCustomAnyEqualPtr>(custom_any_equal.v_ptr);
+      TVM_FFI_ICHECK_NOTNULL(equal_func);
+      return (*equal_func)(lhs, rhs);  // NOLINT(clang-analyzer-core.CallAndMessage)
+    } else {
+      // alternatively it can be a ffi.Function object.
+      TVM_FFI_ICHECK_EQ(custom_any_equal.type_index, TypeIndex::kTVMFFIFunction);
+      // Plain struct copies of the operands' payloads (no refcount change); only valid
+      // for the duration of the call.
+      TVMFFIAny args[2] = {lhs.data_, rhs.data_};
+      TVMFFIAny result;
+      result.type_index = TypeIndex::kTVMFFINone;
+      result.zero_padding = 0;
+      result.v_int64 = 0;
+      TVMFFIFunctionCell* func_cell = TVMFFIFunctionGetCellPtr(custom_any_equal.v_obj);
+      if (func_cell->cpp_call != nullptr) {
+        // Fast path: invoke C++ ABI call directly when available.
+        using FCppCall = void (*)(const void*, const TVMFFIAny*, int32_t, TVMFFIAny*);
+        reinterpret_cast<FCppCall>(func_cell->cpp_call)(custom_any_equal.v_obj, args, 2,
+                                                        &result);
+      } else {
+        if (func_cell->safe_call(custom_any_equal.v_obj, args, 2, &result) != 0) {
+          throw details::MoveFromSafeCallRaised();
+        }
+      }
+      Any result_any = details::AnyUnsafe::MoveTVMFFIAnyToAny(&result);
+      // Same check form as the hash path so a mismatch reports both operands.
+      TVM_FFI_ICHECK_EQ(result_any.type_index(), TypeIndex::kTVMFFIBool);
+      return result_any.data_.v_int64 != 0;
+    }
+  }
 };
 }  // namespace ffi
 
diff --git a/include/tvm/ffi/endian.h b/include/tvm/ffi/endian.h
index 10639bea..c068c3d5 100644
--- a/include/tvm/ffi/endian.h
+++ b/include/tvm/ffi/endian.h
@@ -1,4 +1,3 @@
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
diff --git a/include/tvm/ffi/error.h b/include/tvm/ffi/error.h
index f1cfa213..3ecfbb72 100644
--- a/include/tvm/ffi/error.h
+++ b/include/tvm/ffi/error.h
@@ -317,6 +317,25 @@ inline Error EnvErrorAlreadySet() { return 
Error("EnvErrorAlreadySet", "", "");
 
 namespace details {
 
+/*!
+ * \brief Move the last raised safe-call error from TLS.
+ * \return The raised error object.
+ * \note Lives in error.h (moved from function_details.h) so headers that avoid a
+ *       dependency on function.h, such as any.h, can rethrow safe-call failures.
+ */
+TVM_FFI_INLINE Error MoveFromSafeCallRaised() {
+  TVMFFIObjectHandle handle;
+  TVMFFIErrorMoveFromRaised(&handle);
+  // handle is owned by caller; ObjectPtrFromOwned adopts that reference.
+  return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
+      details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)));
+}
+
+/*!
+ * \brief Record \p error as the pending safe-call error in TLS.
+ * \param error The error to raise across the C ABI boundary.
+ */
+TVM_FFI_INLINE void SetSafeCallRaised(const Error& error) {
+  auto* raw_error = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error);
+  TVMFFIErrorSetRaised(raw_error);
+}
+
 class ErrorBuilder {
  public:
   explicit ErrorBuilder(std::string kind, std::string backtrace, bool 
log_before_throw)
diff --git a/include/tvm/ffi/function_details.h 
b/include/tvm/ffi/function_details.h
index d04c697f..cf33a79c 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -231,26 +231,6 @@ TVM_FFI_INLINE void 
unpack_call(std::index_sequence<Is...>, const std::string* o
   // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks
 }
 
-/*!
- * \brief Move the safe call raised error to the caller
- * \return The error
- */
-TVM_FFI_INLINE static Error MoveFromSafeCallRaised() {
-  TVMFFIObjectHandle handle;
-  TVMFFIErrorMoveFromRaised(&handle);
-  // handle is owned by caller
-  return details::ObjectUnsafe::ObjectRefFromObjectPtr<Error>(
-      
details::ObjectUnsafe::ObjectPtrFromOwned<Object>(static_cast<TVMFFIObject*>(handle)));
-}
-
-/*!
- * \brief Set the safe call raised error
- * \param error The error
- */
-TVM_FFI_INLINE static void SetSafeCallRaised(const Error& error) {
-  
TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error));
-}
-
 template <typename T>
 struct TypeSchemaImpl {
   static std::string v() {
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index 6ebc7c07..570fd755 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -63,6 +63,8 @@ ObjectRef GetMissingObject() {
 
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
+  refl::EnsureTypeAttrColumn("__any_hash__");
+  refl::EnsureTypeAttrColumn("__any_equal__");
   refl::GlobalDef()
       .def_packed("ffi.Array",
                   [](ffi::PackedArgs args, Any* ret) {
diff --git a/tests/cpp/test_any.cc b/tests/cpp/test_any.cc
index c29ad5c8..2686b7a8 100644
--- a/tests/cpp/test_any.cc
+++ b/tests/cpp/test_any.cc
@@ -428,4 +428,34 @@ TEST(Any, AnyEqualHash) {
   EXPECT_EQ(AnyHash()(c), AnyHash()(d));
 }
 
+TEST(Any, CustomAnyHash) {
+  // TInt registers a raw function pointer, exercising the OpaquePtr branch.
+  Any int_src = TInt(7);
+  const uint64_t int_custom = static_cast<uint64_t>(TInt::CustomAnyHash(int_src));
+  EXPECT_EQ(AnyHash()(int_src), details::StableHashCombine(int_src.type_index(), int_custom));
+
+  // TFloat registers an ffi.Function, exercising the ffi.Function branch.
+  Any float_src = TFloat(3.5);
+  const uint64_t float_custom = static_cast<uint64_t>(TFloat::CustomAnyHash(float_src));
+  EXPECT_EQ(AnyHash()(float_src),
+            details::StableHashCombine(float_src.type_index(), float_custom));
+}
+
+TEST(Any, CustomAnyEqual) {
+  // Covers the OpaquePtr custom equal branch.
+  Any int_lhs = TInt(7);
+  Any int_rhs = TInt(7);
+  Any int_diff = TInt(8);
+  EXPECT_TRUE(AnyEqual()(int_lhs, int_rhs));
+  EXPECT_FALSE(AnyEqual()(int_lhs, int_diff));
+  // Distinct objects that compare equal must also hash equal, or hash containers break.
+  EXPECT_EQ(AnyHash()(int_lhs), AnyHash()(int_rhs));
+
+  // Covers the ffi.Function custom equal branch.
+  Any float_lhs = TFloat(3.5);
+  Any float_rhs = TFloat(3.5);
+  Any float_diff = TFloat(4.5);
+  EXPECT_TRUE(AnyEqual()(float_lhs, float_rhs));
+  EXPECT_FALSE(AnyEqual()(float_lhs, float_diff));
+  EXPECT_EQ(AnyHash()(float_lhs), AnyHash()(float_rhs));
+
+  // Different type indices never reach the custom hooks and compare unequal.
+  Any float_seven = TFloat(7.0);
+  EXPECT_FALSE(AnyEqual()(int_lhs, float_seven));
+}
+
 }  // namespace
diff --git a/tests/cpp/testing_object.h b/tests/cpp/testing_object.h
index 318aa116..b3ce5504 100644
--- a/tests/cpp/testing_object.h
+++ b/tests/cpp/testing_object.h
@@ -20,6 +20,7 @@
 #ifndef TVM_FFI_TESTING_OBJECT_H_
 #define TVM_FFI_TESTING_OBJECT_H_
 
+#include <tvm/ffi/any.h>
 #include <tvm/ffi/container/array.h>
 #include <tvm/ffi/container/map.h>
 #include <tvm/ffi/memory.h>
@@ -74,6 +75,14 @@ class TInt : public TNumber {
 
   static TInt StaticAdd(TInt lhs, TInt rhs) { return TInt(lhs->value + 
rhs->value); }
 
+  // Test hook registered as a raw function pointer for "__any_hash__"; the +1024 offset is
+  // presumably just a marker that makes the custom result distinguishable — confirm intent.
+  static int64_t CustomAnyHash(const Any& src) {
+    TInt as_int = src.cast<TInt>();
+    return static_cast<int64_t>(as_int->value + 1024);
+  }
+
+  // Test hook for "__any_equal__": compares the wrapped integer values, not object identity.
+  static bool CustomAnyEqual(const Any& lhs, const Any& rhs) {
+    TInt left = lhs.cast<TInt>();
+    TInt right = rhs.cast<TInt>();
+    return left->value == right->value;
+  }
+
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TInt, TNumber, TIntObj);
 };
 
@@ -85,7 +94,9 @@ inline void TIntObj::RegisterReflection() {
   // define extra type attributes
   refl::TypeAttrDef<TIntObj>()
       .def("test.GetValue", &TIntObj::GetValue)
-      .attr("test.size", sizeof(TIntObj));
+      .attr("test.size", sizeof(TIntObj))
+      .attr("__any_hash__", reinterpret_cast<void*>(&TInt::CustomAnyHash))
+      .attr("__any_equal__", reinterpret_cast<void*>(&TInt::CustomAnyEqual));
   // custom json serialization
   refl::TypeAttrDef<TIntObj>()
       .def("__data_to_json__",
@@ -105,14 +116,7 @@ class TFloatObj : public TNumberObj {
 
   double Add(double other) const { return value + other; }
 
-  static void RegisterReflection() {
-    namespace refl = tvm::ffi::reflection;
-    refl::ObjectDef<TFloatObj>()
-        .def_ro("value", &TFloatObj::value, "float value field", 
refl::DefaultValue(10.0))
-        .def("sub",
-             [](const TFloatObj* self, double other) -> double { return 
self->value - other; })
-        .def("add", &TFloatObj::Add, "add method");
-  }
+  static void RegisterReflection();
 
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.Float", TFloatObj, TNumberObj);
 };
@@ -121,9 +125,29 @@ class TFloat : public TNumber {
  public:
   explicit TFloat(double value) { data_ = make_object<TFloatObj>(value); }
 
+  /*!
+   * \brief Test hook for "__any_hash__", registered as an ffi.Function.
+   * \param src An Any holding a TFloat.
+   * \return value * 10 + 2048, truncated toward zero to int64_t.
+   * \note Returns int64_t to match TInt::CustomAnyHash, the `int64_t (*)(const Any&)`
+   *       function-pointer ABI, and the kTVMFFIInt result that AnyHash checks for.
+   */
+  static int64_t CustomAnyHash(const Any& src) {
+    double value = src.cast<TFloat>()->value;
+    return static_cast<int64_t>(value * 10 + 2048);
+  }
+
+  // Test hook for "__any_equal__": exact (==) comparison of the wrapped doubles.
+  static bool CustomAnyEqual(const Any& lhs, const Any& rhs) {
+    TFloat left = lhs.cast<TFloat>();
+    TFloat right = rhs.cast<TFloat>();
+    return left->value == right->value;
+  }
+
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TFloat, TNumber, TFloatObj);
 };
 
+// Defined out of class, after TFloat, because it references TFloat::CustomAnyHash and
+// TFloat::CustomAnyEqual, which are not declared until the TFloat class.
+inline void TFloatObj::RegisterReflection() {
+  namespace refl = tvm::ffi::reflection;
+  // Reflected field and methods.
+  refl::ObjectDef<TFloatObj>()
+      .def_ro("value", &TFloatObj::value, "float value field", refl::DefaultValue(10.0))
+      .def("sub", [](const TFloatObj* self, double other) -> double { return self->value - other; })
+      .def("add", &TFloatObj::Add, "add method");
+  // Custom hash/equal registered via .def (ffi.Function objects), unlike TInt, which
+  // registers raw function pointers, so both AnyHash/AnyEqual dispatch paths get tested.
+  refl::TypeAttrDef<TFloatObj>()
+      .def("__any_hash__", &TFloat::CustomAnyHash)
+      .def("__any_equal__", &TFloat::CustomAnyEqual);
+}
+
 class TPrimExprObj : public Object {
  public:
   std::string dtype;

Reply via email to