This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch sequal-upgrade
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/sequal-upgrade by this push:
new 54aea761a7 streamline custom override
54aea761a7 is described below
commit 54aea761a70ff638dafa2b0c3e95a6e85e4c8f96
Author: tqchen <[email protected]>
AuthorDate: Sat Jul 26 16:33:10 2025 -0400
streamline custom override
---
ffi/include/tvm/ffi/c_api.h | 21 --------------------
ffi/src/ffi/reflection/structural_equal.cc | 8 +++-----
ffi/src/ffi/reflection/structural_hash.cc | 31 +++++++++++++++---------------
include/tvm/ir/expr.h | 15 +++++++++++++++
include/tvm/ir/module.h | 6 +++---
include/tvm/tir/function.h | 25 ++++++++++++++++++++++++
src/ir/module.cc | 20 +++++++++----------
7 files changed, 70 insertions(+), 56 deletions(-)
diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index e2de610a5d..60743b82c6 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -424,27 +424,6 @@ typedef enum {
* is only an unique copy of each value.
*/
kTVMFFISEqHashKindUniqueInstance = 5,
- /*!
- * \brief provide custom __s_equal__ and __s_hash__ functions through
TypeAttrColumn.
- *
- * The function signatures are(defined via ffi::Function)
- *
- * \code
- * bool __s_equal__(
- * ObjectRefType self, ObjectRefType other,
- * ffi::TypedFunction<bool(AnyView, AnyView, bool def_region, string
field_name)> cmp,
- * );
- *
- * uint64_t __s_hash__(
- * ObjectRefType self, uint64_t type_key_hash,
- * ffi::TypedFunction<uint64_t(AnyView, bool def_region)> hash
- * );
- * \endcode
- *
- * Where the extra string field in cmp is the name of the field that is
being compared.
- * The function should be registered through TVMFFITypeRegisterAttr via
reflection::TypeAttrDef.
- */
- kTVMFFISEqHashKindCustomTreeNode = 6,
#ifdef __cplusplus
};
#else
diff --git a/ffi/src/ffi/reflection/structural_equal.cc b/ffi/src/ffi/reflection/structural_equal.cc
index fff6fd0c1d..e44a0c3256 100644
--- a/ffi/src/ffi/reflection/structural_equal.cc
+++ b/ffi/src/ffi/reflection/structural_equal.cc
@@ -130,8 +130,10 @@ class StructEqualHandler {
}
}
+ static reflection::TypeAttrColumn custom_s_equal =
reflection::TypeAttrColumn("__s_equal__");
+
bool success = true;
- if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+ if (custom_s_equal[type_info->type_index] == nullptr) {
// We recursively compare the fields the object
ForEachFieldInfoWithEarlyStop(type_info, [&](const TVMFFIFieldInfo*
field_info) {
// skip fields that are marked as structural eq hash ignore
@@ -165,7 +167,6 @@ class StructEqualHandler {
}
});
} else {
- static reflection::TypeAttrColumn custom_s_equal =
reflection::TypeAttrColumn("__s_equal__");
// run custom equal function defined via __s_equal__ type attribute
if (s_equal_callback_ == nullptr) {
s_equal_callback_ = ffi::Function::FromTyped(
@@ -191,9 +192,6 @@ class StructEqualHandler {
return success;
});
}
- TVM_FFI_ICHECK(custom_s_equal[type_info->type_index] != nullptr)
- << "TypeAttr `__s_equal__` is not registered for type `" <<
String(type_info->type_key)
- << "`";
success = custom_s_equal[type_info->type_index]
.cast<ffi::Function>()(lhs, rhs, s_equal_callback_)
.cast<bool>();
diff --git a/ffi/src/ffi/reflection/structural_hash.cc b/ffi/src/ffi/reflection/structural_hash.cc
index 78cc294a64..e8ffcf6d2a 100644
--- a/ffi/src/ffi/reflection/structural_hash.cc
+++ b/ffi/src/ffi/reflection/structural_hash.cc
@@ -113,9 +113,11 @@ class StructuralHashHandler {
return it->second;
}
+ static reflection::TypeAttrColumn custom_s_hash =
reflection::TypeAttrColumn("__s_hash__");
+
// compute the hash value
uint64_t hash_value = obj->GetTypeKeyHash();
- if (structural_eq_hash_kind != kTVMFFISEqHashKindCustomTreeNode) {
+ if (custom_s_hash[type_info->type_index] == nullptr) {
// go over the content and hash the fields
ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
// skip fields that are marked as structural eq hash ignore
@@ -135,22 +137,19 @@ class StructuralHashHandler {
}
});
} else {
- static reflection::TypeAttrColumn custom_s_hash =
reflection::TypeAttrColumn("__s_hash__");
- TVM_FFI_ICHECK(custom_s_hash[type_info->type_index] != nullptr)
- << "TypeAttr `__s_hash__` is not registered for type `" <<
String(type_info->type_key)
- << "`";
if (s_hash_callback_ == nullptr) {
- s_hash_callback_ = ffi::Function::FromTyped([this](AnyView val, bool
def_region) {
- if (def_region) {
- bool allow_free_var = true;
- std::swap(allow_free_var, map_free_vars_);
- uint64_t hash_value = HashAny(val);
- std::swap(allow_free_var, map_free_vars_);
- return hash_value;
- } else {
- return HashAny(val);
- }
- });
+ s_hash_callback_ =
+ ffi::Function::FromTyped([this](AnyView val, uint64_t init_hash,
bool def_region) {
+ if (def_region) {
+ bool allow_free_var = true;
+ std::swap(allow_free_var, map_free_vars_);
+ uint64_t hash_value = HashAny(val);
+ std::swap(allow_free_var, map_free_vars_);
+ return details::StableHashCombine(init_hash, hash_value);
+ } else {
+ return details::StableHashCombine(init_hash, HashAny(val));
+ }
+ });
}
hash_value = custom_s_hash[type_info->type_index]
.cast<ffi::Function>()(obj, hash_value,
s_hash_callback_)
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 0f6bce687b..31c26c0a33 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -465,6 +465,11 @@ class GlobalVarNode : public RelaxExprNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GlobalVarNode>().def_ro("name_hint",
&GlobalVarNode::name_hint);
+ // register custom structural equal and hash.
+ // skip checking struct_info_ for now
+ refl::TypeAttrDef<GlobalVarNode>()
+ .def("__s_equal__", &GlobalVarNode::SEqual)
+ .def("__s_hash__", &GlobalVarNode::SHash);
}
bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
@@ -477,6 +482,16 @@ class GlobalVarNode : public RelaxExprNode {
hash_reduce.FreeVarHashImpl(this);
}
+ bool SEqual(const GlobalVarNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal)
const {
+ return equal(name_hint, other->name_hint, false, "name_hint");
+ }
+
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const {
+ return hash(name_hint, init_hash, false);
+ }
+
static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindFreeVar;
static constexpr const char* _type_key = "ir.GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode);
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index bde289b5d0..66c26b0629 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -150,8 +150,8 @@ class IRModuleNode : public Object {
TVM_DLL bool SEqual(const IRModuleNode* other,
ffi::TypedFunction<bool(AnyView, AnyView, bool,
AnyView)> equal) const;
- TVM_DLL uint64_t SHash(uint64_t type_key_hash,
- ffi::TypedFunction<uint64_t(AnyView, bool)> hash)
const;
+ TVM_DLL uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)>
hash) const;
/*!
* \brief Add a function to the global environment.
@@ -246,7 +246,7 @@ class IRModuleNode : public Object {
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();
static constexpr const char* _type_key = "ir.IRModule";
- static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindCustomTreeNode;
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind =
kTVMFFISEqHashKindTreeNode;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 49df3baff8..472d06ada8 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -107,6 +107,11 @@ class PrimFuncNode : public BaseFuncNode {
.def_ro("ret_type", &PrimFuncNode::ret_type)
.def_ro("buffer_map", &PrimFuncNode::buffer_map)
.def_ro("body", &PrimFuncNode::body);
+ // register custom structural equal and hash.
+ // skip checking struct_info_ for now
+ refl::TypeAttrDef<PrimFuncNode>()
+ .def("__s_equal__", &PrimFuncNode::SEqual)
+ .def("__s_hash__", &PrimFuncNode::SHash);
}
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
@@ -123,6 +128,26 @@ class PrimFuncNode : public BaseFuncNode {
hash_reduce(body);
hash_reduce(attrs);
}
+
+ bool SEqual(const PrimFuncNode* other,
+ ffi::TypedFunction<bool(AnyView, AnyView, bool, AnyView)> equal)
const {
+ return equal(params, other->params, true, "params") &&
+ equal(buffer_map, other->buffer_map, false, "buffer_map") &&
+ equal(ret_type, other->ret_type, false, "ret_type") &&
+ equal(body, other->body, false, "body") && equal(attrs,
other->attrs, false, "attrs");
+ }
+
+ uint64_t SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t, bool)> hash)
const {
+ uint64_t hash_value = init_hash;
+ hash_value = hash(params, hash_value, true);
+ hash_value = hash(buffer_map, hash_value, false);
+ hash_value = hash(ret_type, hash_value, false);
+ hash_value = hash(body, hash_value, false);
+ hash_value = hash(attrs, hash_value, false);
+ return hash_value;
+ }
+
/*!
* \brief Return the derived function annotation of this function.
*
diff --git a/src/ir/module.cc b/src/ir/module.cc
index e74f41e698..f178747246 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -151,11 +151,11 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(this->global_infos);
}
-uint64_t IRModuleNode::SHash(uint64_t type_key_hash,
- ffi::TypedFunction<uint64_t(AnyView, bool)> hash)
const {
- uint64_t hash_value = type_key_hash;
- hash_value = tvm::ffi::details::StableHashCombine(hash_value,
hash(this->attrs, false));
- hash_value = tvm::ffi::details::StableHashCombine(hash_value,
hash(this->global_infos, false));
+uint64_t IRModuleNode::SHash(uint64_t init_hash,
+ ffi::TypedFunction<uint64_t(AnyView, uint64_t,
bool)> hash) const {
+ uint64_t hash_value = init_hash;
+ hash_value = hash(this->attrs, hash_value, false);
+ hash_value = hash(this->global_infos, hash_value, false);
// hash the functions.
using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
@@ -166,17 +166,15 @@ uint64_t IRModuleNode::SHash(uint64_t type_key_hash,
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(),
[](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) <
std::get<0>(rhs); });
- hash_value = tvm::ffi::details::StableHashCombine(hash_value,
static_cast<uint64_t>(temp.size()));
+ hash_value = hash(static_cast<uint64_t>(temp.size()), hash_value, false);
// first need to define the GlobalVar in the order of the keys
for (size_t i = 0; i < temp.size(); ++i) {
- hash_value = tvm::ffi::details::StableHashCombine(hash_value,
hash(std::get<1>(temp[i]), true));
+ hash_value = hash(std::get<1>(temp[i]), hash_value, true);
}
// hash the name and content
for (size_t i = 0; i < temp.size(); ++i) {
- hash_value =
- tvm::ffi::details::StableHashCombine(hash_value,
hash(std::get<0>(temp[i]), false));
- hash_value =
- tvm::ffi::details::StableHashCombine(hash_value,
hash(std::get<2>(temp[i]), false));
+ hash_value = hash(std::get<0>(temp[i]), hash_value, false);
+ hash_value = hash(std::get<2>(temp[i]), hash_value, false);
}
return hash_value;
}