This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new a2edd01 relay::StructuralHash to tvm::StructuralHash (#5166)
a2edd01 is described below
commit a2edd01b22c6db475394c386b4226f90413bd0e9
Author: Zhi <[email protected]>
AuthorDate: Sun Mar 29 13:03:25 2020 -0700
relay::StructuralHash to tvm::StructuralHash (#5166)
---
include/tvm/relay/analysis.h | 22 --
python/tvm/relay/analysis/analysis.py | 26 +-
python/tvm/relay/frontend/tensorflow.py | 2 +-
python/tvm/relay/testing/py_converter.py | 2 +-
src/relay/analysis/extract_fused_functions.cc | 3 +-
src/relay/backend/compile_engine.h | 3 +-
src/relay/backend/vm/lambda_lift.cc | 3 +-
src/relay/ir/hash.cc | 437 --------------------------
tests/python/relay/test_pass_qnn_legalize.py | 3 +-
9 files changed, 11 insertions(+), 490 deletions(-)
diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index 51eae5a..e04b4e6 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -225,28 +225,6 @@ TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
*/
TVM_DLL Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod);
-/*! \brief A hashing structure in the style of std::hash. */
-struct StructuralHash {
- /*! \brief Hash a Relay type.
- *
- * Implements structural hashing of a Relay type.
- *
- * \param type the type to hash.
- *
- * \return the hash value.
- */
- size_t operator()(const Type& type) const;
- /*! \brief Hash a Relay expression.
- *
- * Implements structural hashing of a Relay expression.
- *
- * \param expr the expression to hash.
- *
- * \return the hash value.
- */
- size_t operator()(const Expr& expr) const;
-};
-
} // namespace relay
} // namespace tvm
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index b09a40b..21f3edf 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -20,11 +20,10 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
-from tvm.ir import RelayExpr, IRModule
+from tvm.ir import IRModule
from . import _ffi_api
from .feature import Feature
-from ..ty import Type
def post_order_visit(expr, fvisit):
@@ -314,29 +313,6 @@ def detect_feature(a, b=None):
return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}
-def structural_hash(value):
- """Hash a Relay expression structurally.
-
- Parameters
- ----------
- expr : Union[tvm.relay.Expr, tvm.relay.Type]
- The expression to hash.
-
- Returns
- -------
- result : int
- The hash value
- """
- if isinstance(value, RelayExpr):
- return int(_ffi_api._expr_hash(value))
- elif isinstance(value, Type):
- return int(_ffi_api._type_hash(value))
- else:
- msg = ("found value of type {0} expected" +
- "relay.Expr or relay.Type").format(type(value))
- raise TypeError(msg)
-
-
def extract_fused_functions(mod):
"""Pass to extract IRModule of only fused primitive functions.
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index d0b90e5..56aa1d6 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -27,7 +27,7 @@ import tvm
from tvm.ir import IRModule
from tvm.relay.prelude import Prelude
-from tvm.relay.analysis import structural_hash as s_hash
+from tvm.ir import structural_hash as s_hash
from .. import analysis
from .. import expr as _expr
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
index e850000..eec5e16 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -238,7 +238,7 @@ class PythonConverter(ExprFunctor):
# compile the function and register globally
cc_key = compile_engine.CCacheKey(op, self.tgt)
- func_hash = relay.analysis.structural_hash(op)
+ func_hash = tvm.ir.structural_hash(op)
op_name = '_lowered_op_{}'.format(func_hash)
if not tvm.get_global_func(op_name, allow_missing=True):
jitted = self.engine.jit(cc_key, self.tgt)
diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc
index 8cb517f..ff3756c 100644
--- a/src/relay/analysis/extract_fused_functions.cc
+++ b/src/relay/analysis/extract_fused_functions.cc
@@ -21,6 +21,7 @@
* \file extract_fused_functions.cc
* \brief Apply fusion and extract fused primitive functions from an IRModule
*/
+#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
@@ -55,7 +56,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
- size_t hash_ = StructuralHash()(func);
+ size_t hash_ = tvm::StructuralHash()(func);
this->functions.Set(std::to_string(hash_), func);
}
diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h
index eec2bd3..9bd6a4e 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
@@ -258,7 +259,7 @@ bool IsDynamic(const Type& ty);
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
// do structral hash, avoid 0.
- hash_ = StructuralHash()(this->source_func);
+ hash_ = tvm::StructuralHash()(this->source_func);
hash_ = dmlc::HashCombine(
hash_, std::hash<std::string>()(target->str()));
if (hash_ == 0) hash_ = 1;
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 398760f..80745e1 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
@@ -39,7 +40,7 @@ namespace relay {
namespace vm {
inline std::string GenerateName(const Function& func) {
- size_t hash = StructuralHash()(func);
+ size_t hash = tvm::StructuralHash()(func);
return std::string("lifted_name") + std::to_string(hash);
}
diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc
deleted file mode 100644
index ce15e2a..0000000
--- a/src/relay/ir/hash.cc
+++ /dev/null
@@ -1,437 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/relay/ir/hash.cc
- * \brief Hash functions for Relay types and expressions.
- */
-#include <tvm/ir/type_functor.h>
-#include <tvm/tir/ir_pass.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/ir/attrs.h>
-#include "../../ir/attr_functor.h"
-
-namespace tvm {
-namespace relay {
-
-// Hash handler for Relay.
-class RelayHashHandler:
- public AttrsHashHandler,
- public TypeFunctor<size_t(const Type&)>,
- public ExprFunctor<size_t(const Expr&)>,
- public PatternFunctor<size_t(const Pattern&)> {
- public:
- explicit RelayHashHandler() {}
-
- /*!
- * Compute hash of a node.
- * \param ref The node to hash.
- * \return the hash value.
- */
- size_t Hash(const ObjectRef& ref) {
- if (!ref.defined()) return ObjectHash()(ref);
-
- if (ref->IsInstance<TypeNode>()) {
- return TypeHash(Downcast<Type>(ref));
- }
- if (ref->IsInstance<ExprNode>()) {
- return ExprHash(Downcast<Expr>(ref));
- }
- return AttrHash(ref);
- }
-
- /*!
- * Compute hash of the attributes.
- * \param ref The attributes.
- * \return the hash value
- */
- size_t AttrHash(const ObjectRef& ref) {
- if (!ref.defined()) {
- return ObjectHash()(ref);
- }
- return AttrsHashHandler::Hash(ref);
- }
- /*!
- * Compute hash of a Relay type.
- * \param ref The type to hash.
- * \param rhs The right hand operand.
- * \return the hash value.
- */
- size_t TypeHash(const Type& type) {
- if (!type.defined()) {
- return ObjectHash()(type);
- }
- auto found = hash_map_.find(type);
- if (found != hash_map_.end()) {
- return found->second;
- } else {
- auto hash = this->VisitType(type);
- hash_map_.insert({type, hash});
- return hash;
- }
- }
- /*!
- * Compute the hash of an expression.
- *
- * \note We run graph structural equality checking when comparing two Exprs.
- * This means that AlphaEqualHandler can only be used once for each pair.
- * The equality checker checks data-flow equvalence of the Expr DAG.
- * This function also runs faster as it memomizes equal_map.
- *
- * \param expr The expression to hash.
- * \return the hash value.
- */
- size_t ExprHash(const Expr& expr) {
- if (!expr.defined()) {
- return ObjectHash()(expr);
- }
- auto found = hash_map_.find(expr);
- if (found != hash_map_.end()) {
- return found->second;
- } else {
- auto hash = this->VisitExpr(expr);
- hash_map_.insert({expr, hash});
- return hash;
- }
- }
-
- protected:
- /*!
- * \brief Hash a DataType.
- * \param dtype The dtype to hash.
- * \return the hash value.
- */
- size_t DataTypeHash(const DataType& dtype) {
- return ::tvm::AttrsHash()(dtype);
- }
-
- using AttrsHashHandler::VisitAttr_;
- size_t VisitAttr_(const tvm::tir::VarNode* var) final {
- size_t hash = std::hash<std::string>()(VarNode::_type_key);
- auto it = hash_map_.find(GetRef<tvm::tir::Var>(var));
- if (it != hash_map_.end()) {
- return it->second;
- }
- return Combine(hash, std::hash<std::string>()(var->name_hint));
- }
-
- // Type hashing
- size_t VisitType_(const TensorTypeNode* tensor_type) final {
- size_t hash = std::hash<std::string>()(TensorTypeNode::_type_key);
- hash = Combine(hash, DataTypeHash(tensor_type->dtype));
- hash = Combine(hash, Hash(tensor_type->shape));
- return hash;
- }
-
- size_t VisitType_(const IncompleteTypeNode* incomplete) final {
- size_t hash = std::hash<std::string>()(IncompleteTypeNode::_type_key);
- return Combine(hash, std::hash<int>()(incomplete->kind));
- }
-
- size_t VisitType_(const TypeVarNode* tyvar) final {
- /*
- TypeVar/Var/Variable have two locations where they are hashed:
-
- The declaration site of a function, let, or function type.
- The first occurence in the term.
-
- We will only reach this code if the TypeVar itself is unbound, we assign
- a free variable index to it, meaning this hashing function implements
- structural equality for both open (i.e graph equality) and closed terms
- (i.e alpha_equality).
- */
- return BindVar(GetRef<TypeVar>(tyvar));
- }
-
- size_t VisitType_(const FuncTypeNode* func_type) final {
- size_t hash = std::hash<std::string>()(FuncTypeNode::_type_key);
-
- for (auto type_param : func_type->type_params) {
- hash = Combine(hash, BindVar(type_param));
- }
-
- for (auto arg : func_type->arg_types) {
- hash = Combine(hash, TypeHash(arg));
- }
-
- hash = Combine(hash, TypeHash(func_type->ret_type));
- for (auto cs : func_type->type_constraints) {
- hash = Combine(hash, TypeHash(cs));
- }
-
- return hash;
- }
-
- size_t VisitType_(const TypeRelationNode* type_rel) final {
- size_t hash = std::hash<std::string>()(TypeRelationNode::_type_key);
- hash = Combine(hash, std::hash<std::string>()(type_rel->func->name));
- hash = Combine(hash, AttrHash(type_rel->attrs));
-
- for (auto arg : type_rel->args) {
- hash = Combine(hash, TypeHash(arg));
- }
-
- return hash;
- }
-
- size_t VisitType_(const TupleTypeNode* tuple_type) final {
- size_t hash = std::hash<std::string>()(TupleTypeNode::_type_key);
- for (size_t i = 0; i < tuple_type->fields.size(); i++) {
- hash = Combine(hash, TypeHash(tuple_type->fields[i]));
- }
- return hash;
- }
-
- size_t VisitType_(const RelayRefTypeNode* rtn) final {
- size_t hash = std::hash<std::string>()(RelayRefTypeNode::_type_key);
- hash = Combine(hash, TypeHash(rtn->value));
- return hash;
- }
-
- // Expr hashing.
- size_t NDArrayHash(const runtime::NDArray& array) {
- size_t hash = std::hash<uint8_t>()(array->dtype.code);
- hash = Combine(hash, std::hash<uint8_t>()(array->dtype.bits));
- hash = Combine(hash, std::hash<uint16_t>()(array->dtype.lanes));
- CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
- size_t data_size = runtime::GetDataSize(*array.operator->());
- uint8_t * data = reinterpret_cast<uint8_t*>(array->data);
- for (size_t i = 0; i < data_size; i++) {
- hash = Combine(hash, std::hash<uint8_t>()(data[i]));
- }
- return hash;
- }
-
- size_t BindVar(const ObjectRef& var) {
- size_t hash = std::hash<int>()(var_counter++);
- CHECK_EQ(hash_map_.count(var), 0);
- if (auto var_node = var.as<VarNode>()) {
- hash = Combine(hash, TypeHash(var_node->type_annotation));
- }
- hash_map_[var] = hash;
- return hash;
- }
-
- size_t VisitExpr_(const VarNode* var) final {
- // hash free variable
- size_t name_hash = std::hash<const Object*>()(var->vid.get());
- return Combine(name_hash, TypeHash(var->type_annotation));
- }
-
- size_t VisitExpr_(const GlobalVarNode* global) final {
- return std::hash<std::string>()(global->name_hint);
- }
-
- size_t VisitExpr_(const TupleNode* tuple) final {
- size_t hash = std::hash<std::string>()(TupleNode::_type_key);
- for (size_t i = 0; i < tuple->fields.size(); i++) {
- hash = Combine(hash, ExprHash(tuple->fields[i]));
- }
- return hash;
- }
-
- size_t VisitExpr_(const FunctionNode* func) final {
- size_t hash = std::hash<std::string>()(FunctionNode::_type_key);
- for (auto type_param : func->type_params) {
- hash = Combine(hash, BindVar(type_param));
- }
-
- for (auto param : func->params) {
- hash = Combine(hash, BindVar(param));
- }
-
- hash = Combine(hash, TypeHash(func->ret_type));
- hash = Combine(hash, ExprHash(func->body));
-
- hash = Combine(hash, AttrHash(func->attrs));
-
- return hash;
- }
-
- size_t VisitExpr_(const CallNode* call) final {
- size_t hash = std::hash<std::string>()(CallNode::_type_key);
- hash = Combine(hash, ExprHash(call->op));
-
- for (auto arg : call->args) {
- hash = Combine(hash, ExprHash(arg));
- }
-
- for (auto t : call->type_args) {
- CHECK(t.defined());
- hash = Combine(hash, TypeHash(t));
- }
-
- hash = Combine(hash, AttrHash(call->attrs));
-
- return hash;
- }
-
- size_t VisitExpr_(const LetNode* let) final {
- size_t hash = std::hash<std::string>()(LetNode::_type_key);
- hash = Combine(hash, BindVar(let->var));
- hash = Combine(hash, ExprHash(let->value));
- hash = Combine(hash, ExprHash(let->body));
- return hash;
- }
-
- size_t VisitExpr_(const IfNode* ite) final {
- size_t key = std::hash<std::string>()(IfNode::_type_key);
- size_t hash = key;
- hash = Combine(hash, ExprHash(ite->cond));
- hash = Combine(hash, ExprHash(ite->true_branch));
- hash = Combine(hash, ExprHash(ite->false_branch));
- return hash;
- }
-
- size_t VisitExpr_(const OpNode* op) final {
- return ObjectHash()(GetRef<Op>(op));
- }
-
- size_t VisitExpr_(const ConstantNode* rconst) final {
- return NDArrayHash(rconst->data);
- }
-
- size_t VisitExpr_(const TupleGetItemNode* get_item) final {
- size_t hash = std::hash<std::string>()(TupleGetItemNode::_type_key);
- hash = Combine(hash, ExprHash(get_item->tuple));
- hash = Combine(hash, std::hash<int>()(get_item->index));
- return hash;
- }
-
- size_t VisitExpr_(const RefCreateNode* rn) final {
- size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
- hash = Combine(hash, ExprHash(rn->value));
- return hash;
- }
-
- size_t VisitExpr_(const RefReadNode* rn) final {
- size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
- hash = Combine(hash, ExprHash(rn->ref));
- return hash;
- }
-
- size_t VisitExpr_(const RefWriteNode* rn) final {
- size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
- hash = Combine(hash, ExprHash(rn->ref));
- hash = Combine(hash, ExprHash(rn->value));
- return hash;
- }
-
- size_t VisitExpr_(const MatchNode* mn) final {
- size_t hash = std::hash<std::string>()(MatchNode::_type_key);
- hash = Combine(hash, ExprHash(mn->data));
- for (const auto& c : mn->clauses) {
- hash = Combine(hash, PatternHash(c->lhs));
- hash = Combine(hash, ExprHash(c->rhs));
- }
- hash = Combine(hash, std::hash<bool>()(mn->complete));
- return hash;
- }
-
- size_t VisitExpr_(const ConstructorNode* cn) final {
- size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
- hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
- return hash;
- }
-
- size_t VisitType_(const TypeCallNode* tcn) final {
- size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
- hash = Combine(hash, TypeHash(tcn->func));
- for (const auto& t : tcn->args) {
- hash = Combine(hash, TypeHash(t));
- }
- return hash;
- }
-
- size_t VisitType_(const TypeDataNode* tdn) final {
- size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
- hash = Combine(hash, TypeHash(tdn->header));
- for (const auto& tv : tdn->type_vars) {
- hash = Combine(hash, TypeHash(tv));
- }
- for (const auto& cn : tdn->constructors) {
- hash = Combine(hash, ExprHash(cn));
- }
- return hash;
- }
-
- size_t VisitType_(const GlobalTypeVarNode* tvn) final {
- return BindVar(GetRef<GlobalTypeVar>(tvn));
- }
-
- size_t PatternHash(const Pattern& p) {
- return VisitPattern(p);
- }
-
- size_t VisitPattern_(const PatternConstructorNode* pcn) final {
- size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
- hash = Combine(hash, ExprHash(pcn->constructor));
- for (const auto& p : pcn->patterns) {
- hash = Combine(hash, PatternHash(p));
- }
- return hash;
- }
-
- size_t VisitPattern_(const PatternTupleNode* ptn) final {
- size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
- for (const auto& p : ptn->patterns) {
- hash = Combine(hash, PatternHash(p));
- }
- return hash;
- }
-
- size_t VisitPattern_(const PatternVarNode* pvn) final {
- size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
- hash = Combine(hash, BindVar(pvn->var));
- return hash;
- }
-
- size_t VisitPattern_(const PatternWildcardNode* pwn) final {
- size_t hash = std::hash<std::string>()(PatternWildcardNode::_type_key);
- return hash;
- }
- private:
- // renaming of NodeRef to indicate two nodes equals to each other
- std::unordered_map<ObjectRef, size_t, ObjectHash, ObjectEqual> hash_map_;
- int var_counter = 0;
-};
-
-size_t StructuralHash::operator()(const Type& type) const {
- return RelayHashHandler().TypeHash(type);
-}
-
-size_t StructuralHash::operator()(const Expr& expr) const {
- return RelayHashHandler().ExprHash(expr);
-}
-
-TVM_REGISTER_GLOBAL("relay.analysis._expr_hash")
-.set_body_typed([](ObjectRef ref) {
- return static_cast<int64_t>(RelayHashHandler().Hash(ref));
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._type_hash")
-.set_body_typed([](Type type) {
- return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
-});
-
-} // namespace relay
-} // namespace tvm
diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py
index e7980e7..c291c4e 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -31,7 +31,8 @@ def alpha_equal(x, y):
"""
x = x['main']
y = y['main']
- return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
+ return tvm.ir.structural_equal(x, y) and \
+ tvm.ir.structural_hash(x) == tvm.ir.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]