This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6c662eb631 [Unity] Use custom hash in `BlockBuilder` to avoid hashing large constants (#14675)
6c662eb631 is described below
commit 6c662eb631d5c48d66873944920d3b20a31aaf91
Author: masahi <[email protected]>
AuthorDate: Thu Apr 20 21:14:09 2023 +0900
[Unity] Use custom hash in `BlockBuilder` to avoid hashing large constants (#14675)
* use custom hash to avoid hashing constants
* add comments
* refactor
* format
---
src/meta_schedule/module_equality.cc | 13 -------------
src/node/ndarray_hash_equal.h | 7 +++++++
src/node/structural_hash.cc | 10 ++++++++++
src/relax/ir/block_builder.cc | 18 ++++++++++++++++--
4 files changed, 33 insertions(+), 15 deletions(-)
diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc
index 0997aab9b6..b4f37779dd 100644
--- a/src/meta_schedule/module_equality.cc
+++ b/src/meta_schedule/module_equality.cc
@@ -53,19 +53,6 @@ class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
}
};
-class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
- protected:
- void DispatchSHash(const ObjectRef& object, bool map_free_vars) override {
- ICHECK(object.defined());
- if (auto ndarray = object.as<runtime::NDArray::Container>()) {
- SHashReducer hash_reduce(this, map_free_vars);
- NDArrayHash(ndarray, &hash_reduce, false);
- } else {
- SHashHandlerDefault::DispatchSHash(object, map_free_vars);
- }
- }
-};
-
class ModuleEqualityIgnoreNDArray : public ModuleEquality {
public:
size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); }
diff --git a/src/node/ndarray_hash_equal.h b/src/node/ndarray_hash_equal.h
index d674018fbd..b5639f524b 100644
--- a/src/node/ndarray_hash_equal.h
+++ b/src/node/ndarray_hash_equal.h
@@ -19,6 +19,7 @@
#ifndef TVM_NODE_NDARRAY_HASH_EQUAL_H_
#define TVM_NODE_NDARRAY_HASH_EQUAL_H_
+#include <tvm/node/structural_hash.h>
#include <tvm/runtime/ndarray.h>
namespace tvm {
@@ -26,6 +27,12 @@ namespace tvm {
class SEqualReducer;
class SHashReducer;
+/*! \brief A custom hash handler that ignores NDArray raw data. */
+class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
+ protected:
+ void DispatchSHash(const ObjectRef& object, bool map_free_vars) override;
+};
+
/*!
* \brief Test two NDArrays for equality.
* \param lhs The left operand.
diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc
index 6cf796d344..9320fc320e 100644
--- a/src/node/structural_hash.cc
+++ b/src/node/structural_hash.cc
@@ -297,6 +297,16 @@ uint64_t StructuralHash::operator()(const ObjectRef& object) const {
return SHashHandlerDefault().Hash(object, false);
}
+void SHashHandlerIgnoreNDArray::DispatchSHash(const ObjectRef& object, bool map_free_vars) {
+ ICHECK(object.defined());
+ if (auto ndarray = object.as<runtime::NDArray::Container>()) {
+ SHashReducer hash_reduce(this, map_free_vars);
+ NDArrayHash(ndarray, &hash_reduce, false);
+ } else {
+ SHashHandlerDefault::DispatchSHash(object, map_free_vars);
+ }
+}
+
// SEQualReduce traits for runtime containers.
struct StringObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index 1865834a73..fe9e9bf8a5 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -36,6 +36,8 @@
#include <unordered_map>
#include <vector>
+#include "../../node/ndarray_hash_equal.h"
+
// Block builder have three categories of logics that are interdependent with each other.
//
// The logics are somewhat interdependent with each other.
@@ -374,11 +376,23 @@ class BlockBuilderImpl : public BlockBuilderNode {
return name_supply_->FreshName(prefix, /*add_prefix*/ false, /*add_underscore*/ false);
}
+ /*! \brief A custom structural hashing that ignores NDArray raw data. */
+ class StructuralHashIgnoreNDarray : public BaseValueHash {
+ public:
+ using BaseValueHash::operator();
+
+ uint64_t operator()(const ObjectRef& key) const {
+ return SHashHandlerIgnoreNDArray().Hash(key, false);
+ }
+ };
+
/*!
* \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs
* in context_mod to their GlobalVar to avoid generating duplicated functions.
+ * We use a custom hash to avoid hashing constants that may be bound to each BaseFunc.
*/
- std::unique_ptr<std::unordered_map<BaseFunc, GlobalVar, StructuralHash, StructuralEqual>>
+ std::unique_ptr<
+ std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>
ctx_func_dedup_map_ = nullptr;
/*!
@@ -387,7 +401,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
void LazyInitCtxFuncDedupMap() {
if (ctx_func_dedup_map_ != nullptr) return;
ctx_func_dedup_map_ = std::make_unique<
- std::unordered_map<BaseFunc, GlobalVar, StructuralHash, StructuralEqual>>();
+ std::unordered_map<BaseFunc, GlobalVar, StructuralHashIgnoreNDarray, StructuralEqual>>();
for (const auto& kv : context_mod_->functions) {
const GlobalVar gv = kv.first;
const BaseFunc func = kv.second;