This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch sequal-upgrade
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/sequal-upgrade by this push:
     new 54aea761a7 streamline custom override
54aea761a7 is described below

commit 54aea761a70ff638dafa2b0c3e95a6e85e4c8f96
Author: tqchen <[email protected]>
AuthorDate: Sat Jul 26 16:33:10 2025 -0400

    streamline custom override
---
 ffi/include/tvm/ffi/c_api.h                | 21 --------------------
 ffi/src/ffi/reflection/structural_equal.cc |  8 +++-----
 ffi/src/ffi/reflection/structural_hash.cc  | 31 +++++++++++++++---------------
 include/tvm/ir/expr.h                      | 15 +++++++++++++++
 include/tvm/ir/module.h                    |  6 +++---
 include/tvm/tir/function.h                 | 25 ++++++++++++++++++++++++
 src/ir/module.cc                           | 20 +++++++++----------
 7 files changed, 70 insertions(+), 56 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index e2de610a5d..60743b82c6 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -424,27 +424,6 @@ typedef enum {
    * is only an unique copy of each value.
    */
   kTVMFFISEqHashKindUniqueInstance = 5,
-  /*!
-   * \brief provide custom __s_equal__ and __s_hash__ functions through TypeAttrColumn.
-   *
-   * The function signatures are(defined via ffi::Function)
-   *
-   * \code
-   * bool __s_equal__(
-   *   ObjectRefType self, ObjectRefType other,
-   *   ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string field_name)> cmp,
-   * );
-   *
-   * uint64_t __s_hash__(
-   *   ObjectRefType self, uint64_t type_key_hash,
-   *   ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
-   * );
-   * \endcode
-   *
-   * Where the extra string field in cmp is the name of the field that is being compared.
-   * The function should be registered through TVMFFITypeRegisterAttr via reflection::TypeAttrDef.
-   */
-  kTVMFFISEqHashKindCustomTreeNode = 6,
 #ifdef __cplusplus
 };
 #else
diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc
index fff6fd0c1d..e44a0c3256 100644
--- a/ffi/src/ffi/reflection/structural_equal.cc
+++ b/ffi/src/ffi/reflection/structural_equal.cc
@@ -130,8 +130,10 @@ class StructEqualHandler {
       }
     }
 
+    static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");
+
     bool success = true;
-    if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+    if (custom_s_equal[type_info->type_index] == nullptr) {
       // We recursively compare the fields the object
       ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo* field_info) {
         // skip fields that are marked as structural eq hash ignore
@@ -165,7 +167,6 @@ class StructEqualHandler {
         }
       });
     } else {
-      static reflection::TypeAttrColumn custom_s_equal = reflection::TypeAttrColumn("__s_equal__");
       // run custom equal function defined via __s_equal__ type attribute
       if (s_equal_callback_ == nullptr) {
         s_equal_callback_ = ffi::Function::FromTyped(
@@ -191,9 +192,6 @@ class StructEqualHandler {
               return success;
             });
       }
-      TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
-          << "TypeAttr `__s_equal__` is not registered for type `" << String(type_info->type_key)
-          << "`";
       success = custom_s_equal[type_info->type_index]
                     .cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
                     .cast<bool>();
diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/reflection/structural_hash.cc
index 78cc294a64..e8ffcf6d2a 100644
--- a/ffi/src/ffi/reflection/structural_hash.cc
+++ b/ffi/src/ffi/reflection/structural_hash.cc
@@ -113,9 +113,11 @@ class StructuralHashHandler {
       return it->second;
     }
 
+    static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");
+
     // compute the hash value
     uint64_t hash_value = obj->GetTypeKeyHash();
-    if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+    if (custom_s_hash[type_info->type_index] == nullptr) {
       // go over the content and hash the fields
       ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
         // skip fields that are marked as structural eq hash ignore
@@ -135,22 +137,19 @@ class StructuralHashHandler {
         }
       });
     } else {
-      static reflection::TypeAttrColumn custom_s_hash = reflection::TypeAttrColumn("__s_hash__");
-      TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
-          << "TypeAttr `__s_hash__` is not registered for type `" << String(type_info->type_key)
-          << "`";
       if (s_hash_callback_ == nullptr) {
-        s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool def_region) {
-          if (def_region) {
-            bool allow_free_var = true;
-            std::swap(allow_free_var, map_free_vars_);
-            uint64_t hash_value = HashAny(val);
-            std::swap(allow_free_var, map_free_vars_);
-            return hash_value;
-          } else {
-            return HashAny(val);
-          }
-        });
+        s_hash_callback_ =
+            ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash, bool def_region) {
+              if (def_region) {
+                bool allow_free_var = true;
+                std::swap(allow_free_var, map_free_vars_);
+                uint64_t hash_value = HashAny(val);
+                std::swap(allow_free_var, map_free_vars_);
+                return details::StableHashCombine(init_hash, hash_value);
+              } else {
+                return details::StableHashCombine(init_hash, HashAny(val));
+              }
+            });
       }
       hash_value = custom_s_hash[type_info->type_index]
                        .cast<ffi::Function>()(obj, hash_value, s_hash_callback_)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 0f6bce687b..31c26c0a33 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -465,6 +465,11 @@ class GlobalVarNode : public RelaxExprNode {
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
     refl::ObjectDef<GlobalVarNode>().def_ro("name_hint", &GlobalVarNode::name_hint);
+    // register custom structural equal and hash.
+    // skip checking struct_info_ for now
+    refl::TypeAttrDef<GlobalVarNode>()
+        .def("__s_equal__", &GlobalVarNode::SEqual)
+        .def("__s_hash__", &GlobalVarNode::SHash);
   }
 
   bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
@@ -477,6 +482,16 @@ class GlobalVarNode : public RelaxExprNode {
     hash_reduce.FreeVarHashImpl(this);
   }
 
+  bool SEqual(const GlobalVarNode* other,
+              ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
+    return equal(name_hint, other->name_hint, false, "name_hint");
+  }
+
+  uint64_t SHash(uint64_t init_hash,
+                 ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
+    return hash(name_hint, init_hash, false);
+  }
+
   static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar;
   static constexpr const char* _type_key = "ir.GlobalVar";
   TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode);
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index bde289b5d0..66c26b0629 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -150,8 +150,8 @@ class IRModuleNode : public Object {
 
   TVM_DLL bool SEqual(const IRModuleNode* other,
                       ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const;
-  TVM_DLL uint64_t SHash(uint64_t type_key_hash,
-                         ffi::TypedFunction<uint64_t(AnyView, bool)> hash) const;
+  TVM_DLL uint64_t SHash(uint64_t init_hash,
+                         ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const;
 
   /*!
    * \brief Add a function to the global environment.
@@ -246,7 +246,7 @@ class IRModuleNode : public Object {
   TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
 
   static constexpr const char* _type_key = "ir.IRModule";
-  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindCustomTreeNode;
+  static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
   TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 49df3baff8..472d06ada8 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -107,6 +107,11 @@ class PrimFuncNode : public BaseFuncNode {
         .def_ro("ret_type", &PrimFuncNode::ret_type)
         .def_ro("buffer_map", &PrimFuncNode::buffer_map)
         .def_ro("body", &PrimFuncNode::body);
+    // register custom structural equal and hash.
+    // skip checking struct_info_ for now
+    refl::TypeAttrDef<PrimFuncNode>()
+        .def("__s_equal__", &PrimFuncNode::SEqual)
+        .def("__s_hash__", &PrimFuncNode::SHash);
   }
 
   bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
@@ -123,6 +128,26 @@ class PrimFuncNode : public BaseFuncNode {
     hash_reduce(body);
     hash_reduce(attrs);
   }
+
+  bool SEqual(const PrimFuncNode* other,
+              ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal) const {
+    return equal(params, other->params, true, "params") &&
+           equal(buffer_map, other->buffer_map, false, "buffer_map") &&
+           equal(ret_type, other->ret_type, false, "ret_type") &&
+           equal(body, other->body, false, "body") && equal(attrs, other->attrs, false, "attrs");
+  }
+
+  uint64_t SHash(uint64_t init_hash,
+                 ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
+    uint64_t hash_value = init_hash;
+    hash_value = hash(params, hash_value, true);
+    hash_value = hash(buffer_map, hash_value, false);
+    hash_value = hash(ret_type, hash_value, false);
+    hash_value = hash(body, hash_value, false);
+    hash_value = hash(attrs, hash_value, false);
+    return hash_value;
+  }
+
   /*!
    * \brief Return the derived function annotation of this function.
    *
diff --git a/src/ir/module.cc b/src/ir/module.cc
index e74f41e698..f178747246 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -151,11 +151,11 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
   hash_reduce(this->global_infos);
 }
 
-uint64_t IRModuleNode::SHash(uint64_t type_key_hash,
-                             ffi::TypedFunction<uint64_t(AnyView, bool)> hash) const {
-  uint64_t hash_value = type_key_hash;
-  hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(this->attrs, false));
-  hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(this->global_infos, false));
+uint64_t IRModuleNode::SHash(uint64_t init_hash,
+                             ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash) const {
+  uint64_t hash_value = init_hash;
+  hash_value = hash(this->attrs, hash_value, false);
+  hash_value = hash(this->global_infos, hash_value, false);
 
   // hash the functions.
   using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
@@ -166,17 +166,15 @@ uint64_t IRModuleNode::SHash(uint64_t type_key_hash,
   // sort by the hash key of the keys.
   std::sort(temp.begin(), temp.end(),
             [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < std::get<0>(rhs); });
-  hash_value = tvm::ffi::details::StableHashCombine(hash_value, static_cast<uint64_t>(temp.size()));
+  hash_value = hash(static_cast<uint64_t>(temp.size()), hash_value, false);
   // first need to define the GlobalVar in the order of the keys
   for (size_t i = 0; i < temp.size(); ++i) {
-    hash_value = tvm::ffi::details::StableHashCombine(hash_value, hash(std::get<1>(temp[i]), true));
+    hash_value = hash(std::get<1>(temp[i]), hash_value, true);
   }
   // hash the name and content
   for (size_t i = 0; i < temp.size(); ++i) {
-    hash_value =
-        tvm::ffi::details::StableHashCombine(hash_value, hash(std::get<0>(temp[i]), false));
-    hash_value =
-        tvm::ffi::details::StableHashCombine(hash_value, hash(std::get<2>(temp[i]), false));
+    hash_value = hash(std::get<0>(temp[i]), hash_value, false);
+    hash_value = hash(std::get<2>(temp[i]), hash_value, false);
   }
   return hash_value;
 }

Reply via email to