zxybazh commented on code in PR #13050:
URL: https://github.com/apache/tvm/pull/13050#discussion_r995346478


##########
src/meta_schedule/search_strategy/evolutionary_search.cc:
##########
@@ -322,6 +333,8 @@ class EvolutionarySearchNode : public SearchStrategyNode {
     /*! \brief An interface method to be called by it's counterpart in 
EvolutionarySearchNode */
     inline void NotifyRunnerResults(const Array<MeasureCandidate>& 
measure_candidates,
                                     const Array<RunnerResult>& results);
+    /*! \brief Compute the hash for the given module */

Review Comment:
   Please complete the documentation for the params (`\param`) and the return value (`\return`).



##########
include/tvm/meta_schedule/database.h:
##########
@@ -404,23 +438,27 @@ class PyDatabaseNode : public DatabaseNode {
  */
 class Database : public runtime::ObjectRef {
  public:
-  /*! An in-memory database. */
-  TVM_DLL static Database MemoryDatabase();
+  /*! An in-memory database.

Review Comment:
   Nit: add `\brief` to align with the other doc comments.



##########
src/relay/backend/task_extraction.cc:
##########
@@ -42,26 +46,35 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule 
mod, Target target,
   mod = transform::Sequential(pass_seqs)(std::move(mod));
 
   std::vector<ExtractedTask> tasks;
-  std::unordered_map<tec::CCacheKey, ExtractedTask> cache;
+
+  auto mod_eq = meta_schedule::ModuleEquality::Create(mod_eq_name);
+
+  std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual> cache(
+      /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq));
+
   PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, 
&tir_converter](const Expr& exp) {
     if (exp->IsInstance<FunctionNode>()) {
       Function relay_func = Downcast<Function>(exp);
       if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
         return;
       }
-      tec::CCacheKey cache_key(relay_func, target);
-      auto it = cache.find(cache_key);
-      if (it != cache.end()) {
-        it->second->weight += 1;
-        return;
-      }
+
       auto [inputs_outputs, constants, fused_name] =
           tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
+
       if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, 
constants)) {
+        IRModule tir_mod = PrimFuncToIRModule(f.value());
+
+        auto it = cache.find(tir_mod);
+        if (it != cache.end()) {
+          it->second->weight += 1;
+          return;
+        }
+
         IRModule relay_mod({{GlobalVar(fused_name), relay_func}});
-        ExtractedTask task(fused_name, relay_mod, target, 
{PrimFuncToIRModule(f.value())}, 1);
+        ExtractedTask task(fused_name, relay_mod, target, {tir_mod}, 1);
         tasks.push_back(task);
-        cache.emplace(cache_key, task);
+        cache.emplace(tir_mod, task);

Review Comment:
   Does it matter which TIR mod we use as the extracted task? When using 
structural equality, I think the traces can be shared among them without any problem. 
However, if we use a customized equality comparison (for example, one that 
judges based on the anchor op, where `conv2d` and `conv2d+add` are considered 
equal), do we need to make sure the extracted task is `conv2d` instead of 
`conv2d+add`? Otherwise an extra `AutoInline` could be generated for the 
`add` block, and that step doesn't apply to the `conv2d` task.



##########
src/meta_schedule/database/database.cc:
##########
@@ -25,8 +26,8 @@ namespace meta_schedule {
 
 Workload::Workload(IRModule mod) {

Review Comment:
   Do you think we could accept `mod_eq_name` as an argument here so that the 
shash can be customized?



##########
include/tvm/meta_schedule/database.h:
##########
@@ -168,8 +176,16 @@ class TuningRecord : public runtime::ObjectRef {
 /* \brief The abstract interface of database. */
 class DatabaseNode : public runtime::Object {
  public:
+  /*!
+   * \brief Constructor
+   * \param mod_eq_name A string to specify the module equality testing and 
hashing method.
+   *  It must be one of the followings:
+   *    - "structural": Use StructuralEqual/Hash
+   */
+  explicit DatabaseNode(String mod_eq_name = "structural");
+
   /*! \brief Default destructor */
-  virtual ~DatabaseNode() = default;

Review Comment:
   May I ask why this was moved to `database.cc` instead?



##########
src/relay/backend/task_extraction.cc:
##########
@@ -42,26 +46,35 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule 
mod, Target target,
   mod = transform::Sequential(pass_seqs)(std::move(mod));
 
   std::vector<ExtractedTask> tasks;
-  std::unordered_map<tec::CCacheKey, ExtractedTask> cache;
+
+  auto mod_eq = meta_schedule::ModuleEquality::Create(mod_eq_name);
+
+  std::unordered_map<IRModule, ExtractedTask, ModuleHash, ModuleEqual> cache(
+      /*bucket_count*/ 0, ModuleHash(*mod_eq), ModuleEqual(*mod_eq));
+
   PostOrderVisit(mod->Lookup("main"), [&target, &tasks, &cache, 
&tir_converter](const Expr& exp) {
     if (exp->IsInstance<FunctionNode>()) {
       Function relay_func = Downcast<Function>(exp);
       if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) {
         return;
       }
-      tec::CCacheKey cache_key(relay_func, target);
-      auto it = cache.find(cache_key);
-      if (it != cache.end()) {
-        it->second->weight += 1;
-        return;
-      }
+
       auto [inputs_outputs, constants, fused_name] =
           tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
+
       if (Optional<tir::PrimFunc> f = tir_converter(inputs_outputs, 
constants)) {
+        IRModule tir_mod = PrimFuncToIRModule(f.value());
+
+        auto it = cache.find(tir_mod);

Review Comment:
   Let's add a comment here that documents this note.



##########
src/meta_schedule/database/database.cc:
##########
@@ -59,12 +60,8 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) {
       String b64_mod = Downcast<String>(json_array->at(1));
       std::string json_mod = Base64Decode(b64_mod);
       mod = Downcast<IRModule>(LoadJSON(json_mod));
+      std::stringstream(str_shash) >> shash;
     }
-    // Verify SHash(mod) == shash
-    shash = tvm::StructuralHash()(mod);
-    String recalc_shash = SHash2Str(shash);
-    CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. 
Given: " << str_shash
-                                      << "; Recalculated: " << recalc_shash;

Review Comment:
   I think the removal makes sense. On the other hand, we always use 
structural hash as the `shash`, so I think we may not really need the 
ability to specify the `shash` directly in the constructor.
   - What about constructing it from the mod and the `mod_eq_name` string so that 
we can obtain the customized hash result?
   - Maybe we can store the `mod_eq_name` in `Workload` so that when we parse 
it we can still check the shash result?
   
   What do you think?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to