This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 6536b35 remove AttrsEqual and AttrsHash related code (#5169)
6536b35 is described below
commit 6536b356fcd74ee3a2544900c36dfe2340290720
Author: Zhi <[email protected]>
AuthorDate: Sun Mar 29 18:57:27 2020 -0700
remove AttrsEqual and AttrsHash related code (#5169)
---
include/tvm/ir/attrs.h | 171 +------------
src/ir/attr_functor.h | 89 -------
src/ir/attrs.cc | 278 ---------------------
src/node/structural_equal.cc | 1 +
src/relay/transforms/combine_parallel_conv2d.cc | 4 +-
src/relay/transforms/combine_parallel_dense.cc | 2 +-
src/relay/transforms/combine_parallel_op.cc | 3 +-
src/relay/transforms/combine_parallel_op_batch.cc | 4 +-
src/relay/transforms/eliminate_common_subexpr.cc | 2 +-
src/relay/transforms/fold_scale_axis.cc | 4 +-
src/relay/transforms/fuse_ops.cc | 2 +-
src/relay/transforms/pattern_util.h | 2 +-
src/tir/pass/ffi_api.cc | 12 -
tests/python/relay/test_ir_nodes.py | 1 -
tests/python/unittest/test_ir_attrs.py | 9 +-
.../unittest/test_tir_pass_attrs_hash_equal.py | 18 +-
16 files changed, 29 insertions(+), 573 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index fbca3bb..0fc832e 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -46,6 +46,8 @@
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
@@ -131,95 +133,6 @@ class AttrFieldInfo : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
-class AttrsHashHandler;
-class AttrsEqualHandler;
-/*!
- * \brief Content-aware Equality comparator for attrs.
- *
- * This comparator will recursively deep compare the following Attributes.
- *
- * - IntImm, UIntImm, FloatImm, StringImm
- * - Any subclass of BaseAttrsNode
- * - Array of Attributes.
- * - Map from string to Attributes.
- */
-class AttrsEqual {
- public:
- bool operator()(const double& lhs, const double& rhs) const {
- // fuzzy float pt comparison
- constexpr double atol = 1e-9;
- if (lhs == rhs) return true;
- double diff = lhs - rhs;
- return diff > -atol && diff < atol;
- }
-
- bool operator()(const int64_t& lhs, const int64_t& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const int& lhs, const int& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const bool& lhs, const bool& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const std::string& lhs, const std::string& rhs) const {
- return lhs == rhs;
- }
- bool operator()(const DataType& lhs, const DataType& rhs) const {
- return lhs == rhs;
- }
- // node comparator
- TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
-
- protected:
- friend class AttrsEqualHandler;
- /*! \brief internal handle. */
- AttrsEqualHandler* handler_{nullptr};
-};
-
-/*!
- * \brief Content-aware hash function.
- *
- * This hash functor will recursively hash the content of the Attributes.
- * It is guaranteed that if AttrsEqual(a, b) == true, then AttrsHash(a) == AttrsHash(b);
- */
-class AttrsHash {
- public:
- size_t operator()(const double& value) const {
- return std::hash<double>()(value);
- }
- size_t operator()(const int64_t& value) const {
- return std::hash<int64_t>()(value);
- }
- size_t operator()(const uint64_t& value) const {
- return std::hash<uint64_t>()(value);
- }
- size_t operator()(const int& value) const {
- return std::hash<int>()(value);
- }
- size_t operator()(const bool& value) const {
- return std::hash<bool>()(value);
- }
- size_t operator()(const std::string& value) const {
- return std::hash<std::string>()(value);
- }
- size_t operator()(const DataType& value) const {
- return std::hash<int>()(
- static_cast<int>(value.code()) |
- (static_cast<int>(value.bits()) << 8) |
- (static_cast<int>(value.lanes()) << 16));
- }
- TVM_DLL size_t operator()(const ObjectRef& value) const;
-
- private:
- friend class AttrsHashHandler;
- /*! \brief internal handle. */
- AttrsHashHandler* handler_{nullptr};
-};
-
/*!
* \brief Base class of all attribute class
* \note Do not subclass AttrBaseNode directly,
@@ -266,20 +179,6 @@ class BaseAttrsNode : public Object {
* \note This function throws when the required field is not present.
*/
TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;
- /*!
- * \brief Whether this attribute's content equals to another node.
- * \param other The pointer to another node.
- * \param equal The equal comparator
- * \return The comparison result.
- */
- TVM_DLL virtual bool ContentEqual(
- const Object* other, AttrsEqual equal) const = 0;
- /*!
- * \brief Content aware hash.
- * \param hasher The hasher to run the hash.
- * \return the hash result.
- */
- TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -320,8 +219,6 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
- bool ContentEqual(const Object* other, AttrsEqual equal) const final;
- size_t ContentHash(AttrsHash hasher) const final;
// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
@@ -386,34 +283,6 @@ class AttrNormalVisitor {
AttrVisitor* visitor_;
};
-// Wrapper for normal visitor.
-class AttrsEqualVisitor {
- public:
- bool result_{true};
- // constructor
- AttrsEqualVisitor(const Object* lhs, const Object* rhs, const AttrsEqual& equal)
- : lhs_(lhs), rhs_(rhs), equal_(equal) {
- }
- template<typename T>
- AttrNopEntry operator()(const char* key, T* lhs_value) {
- if (!result_) return AttrNopEntry();
- const T* rhs_value =
- reinterpret_cast<const T*>(
- reinterpret_cast<const char*>(rhs_) +
- (reinterpret_cast<const char*>(lhs_value) -
- reinterpret_cast<const char*>(lhs_)));
- if (!equal_(*lhs_value, *rhs_value)) {
- result_ = false;
- }
- return AttrNopEntry();
- }
-
- private:
- const Object* lhs_;
- const Object* rhs_;
- const AttrsEqual& equal_;
-};
-
class AttrsSEqualVisitor {
public:
bool result_{true};
@@ -441,23 +310,6 @@ class AttrsSEqualVisitor {
const SEqualReducer& equal_;
};
-class AttrsHashVisitor {
- public:
- explicit AttrsHashVisitor(const AttrsHash& hasher)
- : hasher_(hasher) {}
-
- size_t result_{0};
-
- template<typename T>
- AttrNopEntry operator()(const char* key, T* value) {
- result_ = dmlc::HashCombine(result_, hasher_(*value));
- return AttrNopEntry();
- }
-
- private:
- const AttrsHash& hasher_;
-};
-
class AttrsSHashVisitor {
public:
explicit AttrsSHashVisitor(const SHashReducer& hash_reducer)
@@ -760,7 +612,7 @@ struct AttrTriggerNonDefaultEntry {
return *this;
}
TSelf& set_default(const T& value) {
- if (AttrsEqual()(value, *data_)) {
+ if (tvm::StructuralEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
@@ -890,23 +742,6 @@ class AttrsNode : public BaseAttrsNode {
return visitor.fields_;
}
- bool ContentEqual(const Object* other, AttrsEqual equal) const final {
- DerivedType* pself = self();
- if (pself == other) return true;
- if (other == nullptr) return false;
- if (pself->type_index() != other->type_index()) return false;
- ::tvm::detail::AttrsEqualVisitor visitor(pself, other, equal);
- self()->__VisitAttrs__(visitor);
- return visitor.result_;
- }
-
- size_t ContentHash(AttrsHash hasher) const final {
- ::tvm::detail::AttrsHashVisitor visitor(hasher);
- visitor.result_ = this->GetTypeKeyHash();
- self()->__VisitAttrs__(visitor);
- return visitor.result_;
- }
-
private:
DerivedType* self() const {
return const_cast<DerivedType*>(
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index 9acc465..dbd5a4f 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -147,94 +147,5 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
}
};
-class AttrsEqualHandler :
- protected AttrFunctor<bool(const ObjectRef&, const ObjectRef&)> {
- public:
- /*!
- * \brief Check if lhs equals rhs
- * \param lhs The left operand.
- * \param rhs The right operand.
- */
- bool Equal(const ObjectRef& lhs, const ObjectRef& rhs);
-
- protected:
- bool VisitAttrDefault_(const Object* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::IntImmNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::FloatImmNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::StringImmNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::AddNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::SubNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::MulNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::DivNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::ModNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::FloorDivNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::FloorModNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::MinNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::MaxNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::GENode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::GTNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::LTNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::LENode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::EQNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::NENode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::AndNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::OrNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::NotNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::CastNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::CallNode* lhs, const ObjectRef& other) final;
- bool VisitAttr_(const tir::SelectNode* lhs, const ObjectRef& other) final;
-};
-
-class AttrsHashHandler :
- protected AttrFunctor<size_t(const ObjectRef&)> {
- public:
- /*!
- * \brief Get hash value of node
- * \param node The node to be hashed.
- */
- size_t Hash(const ObjectRef& node) {
- if (!node.defined()) return 0;
- return this->VisitAttr(node);
- }
-
- protected:
- size_t VisitAttrDefault_(const Object* lhs) final;
- size_t VisitAttr_(const tir::IntImmNode* lhs) final;
- size_t VisitAttr_(const tir::FloatImmNode* lhs) final;
- size_t VisitAttr_(const tir::StringImmNode* lhs) final;
- size_t VisitAttr_(const ArrayNode* lhs) final;
- size_t VisitAttr_(const StrMapNode* lhs) final;
- size_t VisitAttr_(const tir::AddNode* op) final;
- size_t VisitAttr_(const tir::SubNode* op) final;
- size_t VisitAttr_(const tir::MulNode* op) final;
- size_t VisitAttr_(const tir::DivNode* op) final;
- size_t VisitAttr_(const tir::ModNode* op) final;
- size_t VisitAttr_(const tir::FloorDivNode* op) final;
- size_t VisitAttr_(const tir::FloorModNode* op) final;
- size_t VisitAttr_(const tir::MinNode* op) final;
- size_t VisitAttr_(const tir::MaxNode* op) final;
- size_t VisitAttr_(const tir::GENode* op) final;
- size_t VisitAttr_(const tir::GTNode* op) final;
- size_t VisitAttr_(const tir::LENode* op) final;
- size_t VisitAttr_(const tir::LTNode* op) final;
- size_t VisitAttr_(const tir::EQNode* op) final;
- size_t VisitAttr_(const tir::NENode* op) final;
- size_t VisitAttr_(const tir::AndNode* op) final;
- size_t VisitAttr_(const tir::OrNode* op) final;
- size_t VisitAttr_(const tir::NotNode* op) final;
- size_t VisitAttr_(const tir::CastNode* op) final;
- size_t VisitAttr_(const tir::CallNode* op) final;
- size_t VisitAttr_(const tir::SelectNode* op) final;
- /*!
- * \brief alias of dmlc::HashCombine
- * \param lhs The first hash value.
- * \param rhs The second hash value.
- */
- static size_t Combine(size_t lhs, size_t rhs) {
- return dmlc::HashCombine(lhs, rhs);
- }
-};
} // namespace tvm
#endif // TVM_IR_ATTR_FUNCTOR_H_
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 868fec6..066b8f9 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -74,287 +74,9 @@ TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
return attrs->dict;
});
-
-using namespace tir;
-// Equal handler.
-bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() && rhs.defined()) return false;
- if (!rhs.defined() && lhs.defined()) return false;
- return this->VisitAttr(lhs, rhs);
-}
-
-bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
- if (lhs->IsInstance<BaseAttrsNode>()) {
- AttrsEqual equal;
- equal.handler_ = this;
- return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
- other.get(), equal);
- }
- return lhs == other.get();
-}
-
-bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<IntImmNode>()) {
- return lhs->value == rhs->value;
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<FloatImmNode>()) {
- return lhs->value == rhs->value;
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<StringImmNode>()) {
- return lhs->value == rhs->value;
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<ArrayNode>()) {
- if (rhs->data.size() != lhs->data.size()) return false;
- for (size_t i = 0; i < lhs->data.size(); ++i) {
- if (!Equal(lhs->data[i], rhs->data[i])) return false;
- }
- return true;
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<StrMapNode>()) {
- if (rhs->data.size() != lhs->data.size()) return false;
- for (const auto& kv : lhs->data) {
- auto it = rhs->data.find(kv.first);
- if (it == rhs->data.end()) return false;
- if (!Equal(kv.second, it->second)) return false;
- }
- return true;
- } else {
- return false;
- }
-}
-
-#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
- bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
- if (const auto* rhs = other.as<NodeName>()) { \
- if (!Equal(lhs->a, rhs->a)) return false; \
- if (!Equal(lhs->b, rhs->b)) return false; \
- return true; \
- } else { \
- return false; \
- } \
- } \
-
-TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
-TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);
-
-bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<NotNode>()) {
- return Equal(lhs->a, rhs->a);
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<CastNode>()) {
- if (lhs->dtype != rhs->dtype) return false;
- return Equal(lhs->value, rhs->value);
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<CallNode>()) {
- return
- lhs->name == rhs->name &&
- lhs->dtype == rhs->dtype &&
- lhs->call_type == rhs->call_type &&
- Equal(lhs->args, rhs->args);
- } else {
- return false;
- }
-}
-
-bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
- if (const auto* rhs = other.as<SelectNode>()) {
- return
- Equal(lhs->condition, rhs->condition) &&
- Equal(lhs->true_value, rhs->true_value) &&
- Equal(lhs->false_value, rhs->false_value);
- } else {
- return false;
- }
-}
-
-// Hash Handler.
-size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
- if (value->IsInstance<BaseAttrsNode>()) {
- AttrsHash hasher;
- hasher.handler_ = this;
- return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
- } else {
- return ObjectHash()(GetRef<ObjectRef>(value));
- }
-}
-
-size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
- return std::hash<int64_t>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
- return std::hash<double>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
- return std::hash<std::string>()(op->value);
-}
-
-size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
- size_t result = op->data.size();
- for (size_t i = 0; i < op->data.size(); ++i) {
- result = Combine(result, this->Hash(op->data[i]));
- }
- return result;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
- using Entry = std::pair<std::string, ObjectRef>;
- std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
- std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
- return a.first < b.first;
- });
- size_t result = 0;
- for (const Entry& kv : data) {
- result = Combine(result, std::hash<std::string>()(kv.first));
- result = Combine(result, this->Hash(kv.second));
- }
- return result;
-}
-
-
-#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName) \
- size_t AttrsHashHandler::VisitAttr_(const NodeName* op) { \
- static size_t key = std::hash<std::string>()(NodeName::_type_key); \
- return Combine(key, Combine(Hash(op->a), Hash(op->b))); \
- } \
-
-TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
-TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
-TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);
-
-size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
- static size_t key = std::hash<std::string>()(NotNode::_type_key);
- return Combine(key, Hash(op->a));
-}
-
-size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
- static size_t key = std::hash<std::string>()(CastNode::_type_key);
- AttrsHash hasher;
- size_t res = key;
- res = Combine(res, hasher(op->dtype));
- res = Combine(res, Hash(op->value));
- return res;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
- static size_t key = std::hash<std::string>()(CallNode::_type_key);
- AttrsHash hasher;
- size_t res = key;
- res = Combine(res, hasher(op->name));
- res = Combine(res, hasher(op->dtype));
- res = Combine(res, Hash(op->args));
- return res;
-}
-
-size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
- static size_t key = std::hash<std::string>()(SelectNode::_type_key);
- size_t res = key;
- res = Combine(res, Hash(op->condition));
- res = Combine(res, Hash(op->true_value));
- res = Combine(res, Hash(op->false_value));
- return res;
-}
-
-
-// Default case
-bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
- if (lhs.same_as(rhs)) return true;
- if (handler_ == nullptr) {
- return AttrsEqualHandler().Equal(lhs, rhs);
- } else {
- return handler_->Equal(lhs, rhs);
- }
-}
-
-size_t AttrsHash::operator()(const ObjectRef& node) const {
- if (!node.defined()) return 0;
- if (handler_ == nullptr) {
- return AttrsHashHandler().Hash(node);
- } else {
- return handler_->Hash(node);
- }
-}
-
-size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
- return hasher(this->dict);
-}
-
-bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
- if (this == other) return true;
- if (other == nullptr) return false;
- if (this->type_index() != other->type_index()) return false;
- return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
-}
-
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body_typed([](Attrs attrs) {
return attrs->ListFieldInfo();
});
-TVM_REGISTER_GLOBAL("ir.AttrsEqual")
-.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
- return AttrsEqual()(lhs, rhs);
-});
-
} // namespace tvm
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index df7b8ff..b2191c1 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -103,6 +103,7 @@ class RemapVarSEqualHandler :
// Function that implements actual equality check.
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+ if (!lhs.defined() && !rhs.defined()) return true;
task_stack_.clear();
pending_tasks_.clear();
equal_map_lhs_.clear();
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 0dbce9b..3884dac 100644
--- a/src/relay/transforms/combine_parallel_conv2d.cc
+++ b/src/relay/transforms/combine_parallel_conv2d.cc
@@ -59,7 +59,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
- AttrsEqual eq;
+ StructuralEqual eq;
const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
@@ -112,7 +112,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
- AttrsEqual eq;
+ StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index cd234bb..612dae5 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -54,7 +54,7 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
protected:
virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
- AttrsEqual eq;
+ StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
CHECK(attrs_a);
diff --git a/src/relay/transforms/combine_parallel_op.cc b/src/relay/transforms/combine_parallel_op.cc
index 6b9926c..a7f7af2 100644
--- a/src/relay/transforms/combine_parallel_op.cc
+++ b/src/relay/transforms/combine_parallel_op.cc
@@ -23,6 +23,7 @@
* \brief Abstract class to combine parallel ops and their successive
element-wise ops.
*/
+#include <tvm/node/structural_hash.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
@@ -155,7 +156,7 @@ void ParallelOpCombiner::CombineBranches(const Group& branches) {
bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth,
size_t parent_index) {
const CallNode* call = branches[0][depth];
- AttrsEqual attrs_equal;
+ tvm::StructuralEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc
index fa63573..55ca3f6 100644
--- a/src/relay/transforms/combine_parallel_op_batch.cc
+++ b/src/relay/transforms/combine_parallel_op_batch.cc
@@ -76,7 +76,7 @@ bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode
return false;
}
- AttrsEqual eq;
+ StructuralEqual eq;
for (size_t i = 0; i < a->args.size(); i++) {
auto ta = a->args[i]->type_as<TensorTypeNode>();
auto tb = b->args[i]->type_as<TensorTypeNode>();
@@ -112,7 +112,7 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
}
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
- AttrsEqual eq;
+ StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
diff --git a/src/relay/transforms/eliminate_common_subexpr.cc b/src/relay/transforms/eliminate_common_subexpr.cc
index bb31d32..f905ba5 100644
--- a/src/relay/transforms/eliminate_common_subexpr.cc
+++ b/src/relay/transforms/eliminate_common_subexpr.cc
@@ -45,7 +45,7 @@ class CommonSubexprEliminator : public ExprMutator {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
- AttrsEqual attrs_equal;
+ StructuralEqual attrs_equal;
if (new_call->args.size() == 0 || op == nullptr ||
op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
diff --git a/src/relay/transforms/fold_scale_axis.cc b/src/relay/transforms/fold_scale_axis.cc
index c3114c7..49f6e3f 100644
--- a/src/relay/transforms/fold_scale_axis.cc
+++ b/src/relay/transforms/fold_scale_axis.cc
@@ -765,7 +765,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
Message AddSubBackwardPrep(const Call& call, const Array<Message>& in_messages) {
const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
const auto* trhs = call->args[1]->type_as<TensorTypeNode>();
- AttrsEqual equal;
+ StructuralEqual equal;
if (in_messages[0].defined() &&
MatchBroadcastToLeftAxes(tlhs, trhs, in_messages[0]->axes)) {
return in_messages[0];
@@ -795,7 +795,7 @@ Expr AddSubBackwardTransform(const Call& call,
}
Message lhs_message = transformer->GetMessage(call->args[0]);
Message rhs_message = transformer->GetMessage(call->args[1]);
- AttrsEqual equal;
+ StructuralEqual equal;
if (lhs_message.defined() && rhs_message.defined()) {
CHECK(equal(lhs_message->axes, rhs_message->axes));
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 6e95441..9168898 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -162,7 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
// The output.
IndexedForwardGraph graph_;
// attribute equal comparator
- AttrsEqual attr_equal_;
+ StructuralEqual attr_equal_;
// Update the message stored at the node.
void Update(const Expr& node,
IndexedForwardGraph::Node* parent,
diff --git a/src/relay/transforms/pattern_util.h b/src/relay/transforms/pattern_util.h
index 8ce42a2..350d9e1 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -104,7 +104,7 @@ inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
const Array<Integer>& lhs_axes,
Expr* rhs_value = nullptr) {
if (tlhs->shape.size() < trhs->shape.size()) return false;
- AttrsEqual equal;
+ StructuralEqual equal;
size_t base = tlhs->shape.size() - trhs->shape.size();
size_t j = 0;
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 233bfa5..46d0f67 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -101,18 +101,6 @@ TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
-TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
-.set_body_typed(
- [](const ObjectRef& lhs, const ObjectRef& rhs) {
- return AttrsEqual()(lhs, rhs);
- });
-
-TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
-.set_body_typed([](const ObjectRef &node) -> int64_t {
- return AttrsHash()(node);
-});
-
-
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index 6d4a685..dbd5934 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -106,7 +106,6 @@ def test_function():
check_json_roundtrip(fn)
-@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.runtime.convert([relay.var(n, shape=(5, 2)) for n in param_names])
diff --git a/tests/python/unittest/test_ir_attrs.py b/tests/python/unittest/test_ir_attrs.py
index f4148ca..8f2e9bb 100644
--- a/tests/python/unittest/test_ir_attrs.py
+++ b/tests/python/unittest/test_ir_attrs.py
@@ -51,14 +51,13 @@ def test_dict_attrs():
def test_attrs_equal():
- attr_equal = tvm.ir._ffi_api.AttrsEqual
dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
- assert attr_equal(dattr0, dattr1)
- assert not attr_equal(dattr0, dattr2)
- assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
- assert not attr_equal([1, 2], tvm.runtime.convert(1))
+ assert tvm.ir.structural_equal(dattr0, dattr1)
+ assert not tvm.ir.structural_equal(dattr0, dattr2)
+ assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1))
+ assert not tvm.ir.structural_equal([1, 2], tvm.runtime.convert(1))
diff --git a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
index b3587cd..9a115be 100644
--- a/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
+++ b/tests/python/unittest/test_tir_pass_attrs_hash_equal.py
@@ -21,28 +21,28 @@ def test_attrs_equal():
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
z = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4,1))
- assert tvm.tir.ir_pass.AttrsEqual(x, y)
- assert not tvm.tir.ir_pass.AttrsEqual(x, z)
+ assert tvm.ir.structural_equal(x, y)
+ assert not tvm.ir.structural_equal(x, z)
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
- assert not tvm.tir.ir_pass.AttrsEqual(dattr, x)
+ assert not tvm.ir.structural_equal(dattr, x)
dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
- assert tvm.tir.ir_pass.AttrsEqual(dattr, dattr2)
+ assert tvm.ir.structural_equal(dattr, dattr2)
- assert tvm.tir.ir_pass.AttrsEqual({"x": x}, {"x": y})
+ assert tvm.ir.structural_equal({"x": x}, {"x": y})
# array related checks
- assert tvm.tir.ir_pass.AttrsEqual({"x": [x, x]}, {"x": [y, x]})
- assert not tvm.tir.ir_pass.AttrsEqual({"x": [x, 1]}, {"x": [y, 2]})
+ assert tvm.ir.structural_equal({"x": [x, x]}, {"x": [y, x]})
+ assert not tvm.ir.structural_equal({"x": [x, 1]}, {"x": [y, 2]})
n = te.var("n")
- assert tvm.tir.ir_pass.AttrsEqual({"x": n+1}, {"x": n+1})
+ assert tvm.ir.structural_equal({"x": n+1}, {"x": n+1})
def test_attrs_hash():
- fhash = tvm.tir.ir_pass.AttrsHash
+ fhash = tvm.ir.structural_hash
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
y = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3, 4))
assert fhash({"x": x}) == fhash({"x": y})