This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e22a2d5b9f [IR] Enhance IRModule SEqual/SHash to support cross function calls (#14289)
e22a2d5b9f is described below
commit e22a2d5b9f294474b3fadb5cd8a9c429a17f7943
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 15 09:31:38 2023 +0800
[IR] Enhance IRModule SEqual/SHash to support cross function calls (#14289)
Because GlobalVars are defined within the IRModule, they must be defined during the
IRModule SEqual/SHash step (via `DefEqual` and `DefHash`) so that functions which
call each other through GlobalVars compare and hash consistently.
---
src/ir/module.cc | 68 +++++++++++++++++++++++++++++++++-----------------------
1 file changed, 40 insertions(+), 28 deletions(-)
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 42ced96120..7a973da29d 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -63,67 +63,79 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
}
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer
equal) const {
- if (functions.size() != other->functions.size()) return false;
if (!equal(this->attrs, other->attrs)) return false;
- if (equal.IsPathTracingEnabled()) {
- const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
- for (const auto& kv : this->functions) {
- if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
+
+ if (functions.size() != other->functions.size()) return false;
+ // Update GlobalVar remap
+ for (const auto& gv : this->GetGlobalVars()) {
+ if (!other->ContainGlobalVar(gv->name_hint)) return false;
+ if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
+ }
+ // Checking functions
+ for (const auto& kv : this->functions) {
+ if (equal.IsPathTracingEnabled()) {
+ const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
ObjectPathPair func_paths =
{obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("functions")
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths))
return false;
+ } else {
+ if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
}
- if (type_definitions.size() != other->type_definitions.size()) return
false;
- for (const auto& kv : this->type_definitions) {
- if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
- ObjectPathPair type_def_paths = {
+ }
+
+ if (type_definitions.size() != other->type_definitions.size()) return false;
+ // Update GlobalTypeVar remap
+ for (const auto& gtv : this->GetGlobalTypeVars()) {
+ if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
+ if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return
false;
+ }
+ // Checking type_definitions
+ for (const auto& kv : this->type_definitions) {
+ if (equal.IsPathTracingEnabled()) {
+ const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
+ ObjectPathPair type_paths = {
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("type_definitions")
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
- if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint),
type_def_paths))
- return false;
+ if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint),
type_paths)) return false;
+ } else {
+ if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return
false;
}
- return true;
- }
- for (const auto& kv : this->functions) {
- if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
- if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
- }
- if (type_definitions.size() != other->type_definitions.size()) return false;
- for (const auto& kv : this->type_definitions) {
- if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
- if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return
false;
}
return true;
}
void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
- using KV = std::pair<std::string, ObjectRef>;
+ using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
// hash the functions.
std::vector<KV> temp;
auto reduce_temp = [&]() {
// sort by the hash key of the keys.
std::sort(temp.begin(), temp.end(),
- [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first;
});
+ [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) <
std::get<0>(rhs); });
hash_reduce(static_cast<uint64_t>(temp.size()));
- // hash the content
+ // Defhash the GlobalVar/GlobalTypeVar
+ for (size_t i = 0; i < temp.size(); ++i) {
+ hash_reduce.DefHash(std::get<1>(temp[i]));
+ }
+ // hash the name and content
for (size_t i = 0; i < temp.size(); ++i) {
- hash_reduce(temp[i].first);
- hash_reduce(temp[i].second);
+ hash_reduce(std::get<0>(temp[i]));
+ hash_reduce(std::get<2>(temp[i]));
}
};
for (const auto& kv : this->functions) {
- temp.emplace_back(kv.first->name_hint, kv.second);
+ temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
}
reduce_temp();
temp.clear();
for (const auto& kv : this->type_definitions) {
- temp.emplace_back(kv.first->name_hint, kv.second);
+ temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
}
reduce_temp();
hash_reduce(this->attrs);