This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6c662eb631 [Unity] Use custom hash in `BlockBuilder` to avoid hashing large constants (#14675)
6c662eb631 is described below

commit 6c662eb631d5c48d66873944920d3b20a31aaf91
Author: masahi <[email protected]>
AuthorDate: Thu Apr 20 21:14:09 2023 +0900

    [Unity] Use custom hash in `BlockBuilder` to avoid hashing large constants (#14675)
    
    * use custom hash to avoid hashing constants
    
    * add comments
    
    * refactor
    
    * format
---
 src/meta_schedule/module_equality.cc | 13 -------------
 src/node/ndarray_hash_equal.h        |  7 +++++++
 src/node/structural_hash.cc          | 10 ++++++++++
 src/relax/ir/block_builder.cc        | 18 ++++++++++++++++--
 4 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc
index 0997aab9b6..b4f37779dd 100644
--- a/src/meta_schedule/module_equality.cc
+++ b/src/meta_schedule/module_equality.cc
@@ -53,19 +53,6 @@ class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
   }
 };
 
-class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
- protected:
-  void DispatchSHash(const ObjectRef& object, bool map_free_vars) override {
-    ICHECK(object.defined());
-    if (auto ndarray = object.as<runtime::NDArray::Container>()) {
-      SHashReducer hash_reduce(this, map_free_vars);
-      NDArrayHash(ndarray, &hash_reduce, false);
-    } else {
-      SHashHandlerDefault::DispatchSHash(object, map_free_vars);
-    }
-  }
-};
-
 class ModuleEqualityIgnoreNDArray : public ModuleEquality {
  public:
   size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); }
diff --git a/src/node/ndarray_hash_equal.h b/src/node/ndarray_hash_equal.h
index d674018fbd..b5639f524b 100644
--- a/src/node/ndarray_hash_equal.h
+++ b/src/node/ndarray_hash_equal.h
@@ -19,6 +19,7 @@
 #ifndef TVM_NODE_NDARRAY_HASH_EQUAL_H_
 #define TVM_NODE_NDARRAY_HASH_EQUAL_H_
 
+#include <tvm/node/structural_hash.h>
 #include <tvm/runtime/ndarray.h>
 
 namespace tvm {
@@ -26,6 +27,12 @@ namespace tvm {
 class SEqualReducer;
 class SHashReducer;
 
+/*! \brief A custom hash handler that ignores NDArray raw data. */
+class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
+ protected:
+  void DispatchSHash(const ObjectRef& object, bool map_free_vars) override;
+};
+
 /*!
  * \brief Test two NDArrays for equality.
  * \param lhs The left operand.
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 6cf796d344..9320fc320e 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -297,6 +297,16 @@ uint64_t StructuralHash::operator()(const ObjectRef& object) const {
   return SHashHandlerDefault().Hash(object, false);
 }
 
+void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_free_vars) {
+  ICHECK(object.defined());
+  if (auto ndarray = object.as<runtime::NDArray::Container>()) {
+    SHashReducer hash_reduce(this, map_free_vars);
+    NDArrayHash(ndarray, &hash_reduce, false);
+  } else {
+    SHashHandlerDefault::DispatchSHash(object, map_free_vars);
+  }
+}
+
 // SEQualReduce traits for runtime containers.
 struct StringObjTrait {
   static constexpr const std::nullptr_t VisitAttrs = nullptr;
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 1865834a73..fe9e9bf8a5 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -36,6 +36,8 @@
 #include <unordered_map>
 #include <vector>
 
+#include "../../node/ndarray_hash_equal.h"
+
 // Block builder have three categories of logics that are interdependent with each other.
 //
 // The logics are somewhat interdependent with each other.
@@ -374,11 +376,23 @@ class BlockBuilderImpl : public BlockBuilderNode {
     return name_supply_->FreshName(prefix, /*add_prefix*/ false, /*add_underscore*/ false);
   }
 
+  /*! \brief A custom structural hashing that ignores NDArray raw data. */
+  class StructuralHashIgnoreNDarray : public BaseValueHash {
+   public:
+    using BaseValueHash::operator();
+
+    uint64_t operator()(const ObjectRef& key) const {
+      return SHashHandlerIgnoreNDArray().Hash(key, false);
+    }
+  };
+
   /*!
    * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs
    * in context_mod to their GlobalVar to avoid generating duplicated functions.
+   * We use a custom hash to avoid hashing constants that may be bound to each BaseFunc.
    */
-  std::unique_ptr<std::unordered_map<BaseFunc, GlobalVar, StructuralHash, StructuralEqual>>
+  std::unique_ptr<
+      std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>
       ctx_func_dedup_map_ = nullptr;
 
   /*!
@@ -387,7 +401,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
   void LazyInitCtxFuncDedupMap() {
     if (ctx_func_dedup_map_ != nullptr) return;
     ctx_func_dedup_map_ = std::make_unique<
-        std::unordered_map<BaseFunc, GlobalVar, StructuralHash, StructuralEqual>>();
+        std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>();
     for (const auto& kv : context_mod_->functions) {
       const GlobalVar gv = kv.first;
       const BaseFunc func = kv.second;

Reply via email to