This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e22a2d5b9f [IR] Enhance IRModule SEqual/SHash to support cross function calls (#14289)
e22a2d5b9f is described below

commit e22a2d5b9f294474b3fadb5cd8a9c429a17f7943
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 15 09:31:38 2023 +0800

    [IR] Enhance IRModule SEqual/SHash to support cross function calls (#14289)
    
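    Because GlobalVars are defined at the IRModule level, they need to be
    defined (via `DefEqual` and `DefHash`) during the IRModule SEqual/SHash
    step, so that a call from one function to another, which refers to the
    callee by its GlobalVar, is matched consistently across both modules.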
---
 src/ir/module.cc | 68 +++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 40 insertions(+), 28 deletions(-)

diff --git a/src/ir/module.cc b/src/ir/module.cc
index 42ced96120..7a973da29d 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -63,67 +63,79 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
 }
 
 bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer 
equal) const {
-  if (functions.size() != other->functions.size()) return false;
   if (!equal(this->attrs, other->attrs)) return false;
-  if (equal.IsPathTracingEnabled()) {
-    const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
-    for (const auto& kv : this->functions) {
-      if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
+
+  if (functions.size() != other->functions.size()) return false;
+  // Update GlobalVar remap
+  for (const auto& gv : this->GetGlobalVars()) {
+    if (!other->ContainGlobalVar(gv->name_hint)) return false;
+    if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
+  }
+  // Checking functions
+  for (const auto& kv : this->functions) {
+    if (equal.IsPathTracingEnabled()) {
+      const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
       ObjectPathPair func_paths = 
{obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
                                    obj_path_pair->rhs_path->Attr("functions")
                                        
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
       if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) 
return false;
+    } else {
+      if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
     }
-    if (type_definitions.size() != other->type_definitions.size()) return 
false;
-    for (const auto& kv : this->type_definitions) {
-      if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
-      ObjectPathPair type_def_paths = {
+  }
+
+  if (type_definitions.size() != other->type_definitions.size()) return false;
+  // Update GlobalTypeVar remap
+  for (const auto& gtv : this->GetGlobalTypeVars()) {
+    if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
+    if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return 
false;
+  }
+  // Checking type_definitions
+  for (const auto& kv : this->type_definitions) {
+    if (equal.IsPathTracingEnabled()) {
+      const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
+      ObjectPathPair type_paths = {
           
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
           obj_path_pair->rhs_path->Attr("type_definitions")
               ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
-      if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), 
type_def_paths))
-        return false;
+      if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), 
type_paths)) return false;
+    } else {
+      if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return 
false;
     }
-    return true;
-  }
-  for (const auto& kv : this->functions) {
-    if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
-    if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
-  }
-  if (type_definitions.size() != other->type_definitions.size()) return false;
-  for (const auto& kv : this->type_definitions) {
-    if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
-    if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return 
false;
   }
   return true;
 }
 
 void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
-  using KV = std::pair<std::string, ObjectRef>;
+  using KV = std::tuple<std::string, ObjectRef, ObjectRef>;
   // hash the functions.
   std::vector<KV> temp;
 
   auto reduce_temp = [&]() {
     // sort by the hash key of the keys.
     std::sort(temp.begin(), temp.end(),
-              [](const KV& lhs, const KV& rhs) { return lhs.first < rhs.first; 
});
+              [](const KV& lhs, const KV& rhs) { return std::get<0>(lhs) < 
std::get<0>(rhs); });
 
     hash_reduce(static_cast<uint64_t>(temp.size()));
-    // hash the content
+    // Defhash the GlobalVar/GlobalTypeVar
+    for (size_t i = 0; i < temp.size(); ++i) {
+      hash_reduce.DefHash(std::get<1>(temp[i]));
+    }
+    // hash the name and content
     for (size_t i = 0; i < temp.size(); ++i) {
-      hash_reduce(temp[i].first);
-      hash_reduce(temp[i].second);
+      hash_reduce(std::get<0>(temp[i]));
+      hash_reduce(std::get<2>(temp[i]));
     }
   };
 
   for (const auto& kv : this->functions) {
-    temp.emplace_back(kv.first->name_hint, kv.second);
+    temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
   }
   reduce_temp();
 
   temp.clear();
   for (const auto& kv : this->type_definitions) {
-    temp.emplace_back(kv.first->name_hint, kv.second);
+    temp.emplace_back(kv.first->name_hint, kv.first, kv.second);
   }
   reduce_temp();
   hash_reduce(this->attrs);

Reply via email to