This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 29a3d3a [Bugfix][IR][ATTRS] Fix AttrEqual for Array and StrMap, double (#5054)
29a3d3a is described below
commit 29a3d3a66f31235eb644b38d9e03c156fa5fde7f
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Mar 12 16:29:06 2020 -0700
[Bugfix][IR][ATTRS] Fix AttrEqual for Array and StrMap, double (#5054)
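- Use fuzzy comparison (absolute tolerance 1e-9) when comparing double attribute values.
- Array and StrMap comparisons now return false when the other operand is not the
  same container type (previously they fell through to `return true`).
- Removed the special-case handling of BatchNormAttrs and DictAttrs from relay's
  AlphaEqualHandler; it now falls back to the generic AttrsEqualHandler.
Also stopped the text printer from warning about tabs/newlines when printing the
meta-data section, by emitting it through the new Doc::RawText.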
---
include/tvm/ir/attrs.h | 7 ++-
src/ir/attrs.cc | 29 +++++++----
src/printer/doc.cc | 10 ++--
src/printer/doc.h | 5 ++
src/printer/meta_data.h | 2 +-
src/relay/analysis/alpha_equal.cc | 88 +++++++++++++++-------------------
tests/python/unittest/test_ir_attrs.py | 15 +++++-
7 files changed, 92 insertions(+), 64 deletions(-)
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 899db08..4413fc3 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -143,8 +143,13 @@ class AttrsEqualHandler;
class AttrsEqual {
public:
bool operator()(const double& lhs, const double& rhs) const {
- return lhs == rhs;
+ // fuzzy float pt comparison
+ constexpr double atol = 1e-9;
+ if (lhs == rhs) return true;
+ double diff = lhs - rhs;
+ return diff > -atol && diff < atol;
}
+
bool operator()(const int64_t& lhs, const int64_t& rhs) const {
return lhs == rhs;
}
diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc
index 4c4c997..868fec6 100644
--- a/src/ir/attrs.cc
+++ b/src/ir/attrs.cc
@@ -79,7 +79,8 @@ using namespace tir;
// Equal handler.
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() || !rhs.defined()) return false;
+ if (!lhs.defined() && rhs.defined()) return false;
+ if (!rhs.defined() && lhs.defined()) return false;
return this->VisitAttr(lhs, rhs);
}
@@ -96,22 +97,25 @@ bool AttrsEqualHandler::VisitAttrDefault_(const Object*
lhs, const ObjectRef& ot
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef&
other) {
if (const auto* rhs = other.as<IntImmNode>()) {
return lhs->value == rhs->value;
+ } else {
+ return false;
}
- return false;
}
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef&
other) {
if (const auto* rhs = other.as<FloatImmNode>()) {
return lhs->value == rhs->value;
+ } else {
+ return false;
}
- return false;
}
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef&
other) {
if (const auto* rhs = other.as<StringImmNode>()) {
return lhs->value == rhs->value;
+ } else {
+ return false;
}
- return false;
}
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef&
other) {
@@ -120,8 +124,10 @@ bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs,
const ObjectRef& other)
for (size_t i = 0; i < lhs->data.size(); ++i) {
if (!Equal(lhs->data[i], rhs->data[i])) return false;
}
+ return true;
+ } else {
+ return false;
}
- return true;
}
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef&
other) {
@@ -132,8 +138,10 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs,
const ObjectRef& other
if (it == rhs->data.end()) return false;
if (!Equal(kv.second, it->second)) return false;
}
+ return true;
+ } else {
+ return false;
}
- return true;
}
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName) \
@@ -340,8 +348,13 @@ bool DictAttrsNode::ContentEqual(const Object* other,
AttrsEqual equal) const {
}
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
-.set_body([](TVMArgs args, TVMRetValue* ret) {
- *ret = args[0].operator Attrs()->ListFieldInfo();
+.set_body_typed([](Attrs attrs) {
+ return attrs->ListFieldInfo();
+});
+
+TVM_REGISTER_GLOBAL("ir.AttrsEqual")
+.set_body_typed([](ObjectRef lhs, ObjectRef rhs) {
+ return AttrsEqual()(lhs, rhs);
});
} // namespace tvm
diff --git a/src/printer/doc.cc b/src/printer/doc.cc
index c5595db..ee260f4 100644
--- a/src/printer/doc.cc
+++ b/src/printer/doc.cc
@@ -40,9 +40,6 @@ class DocTextNode : public DocAtomNode {
explicit DocTextNode(std::string str_val)
: str(str_val) {
- if (str.find_first_of("\t\n") != str.npos) {
- LOG(WARNING) << "text node: '" << str << "' should not has tab or
newline.";
- }
}
static constexpr const char* _type_key = "printer.DocText";
@@ -54,6 +51,9 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode);
class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
+ if (str.find_first_of("\t\n") != str.npos) {
+ LOG(WARNING) << "text node: '" << str << "' should not has tab or
newline.";
+ }
data_ = runtime::make_object<DocTextNode>(str);
}
@@ -125,6 +125,10 @@ Doc Doc::Text(std::string text) {
return Doc() << DocText(text);
}
+Doc Doc::RawText(std::string text) {
+ return Doc() << DocAtom(runtime::make_object<DocTextNode>(text));
+}
+
Doc Doc::Indent(int indent, Doc doc) {
for (size_t i = 0; i < doc.stream_.size(); ++i) {
if (auto* line = doc.stream_[i].as<DocLineNode>()) {
diff --git a/src/printer/doc.h b/src/printer/doc.h
index 34a284b..7d8d72e 100644
--- a/src/printer/doc.h
+++ b/src/printer/doc.h
@@ -111,6 +111,11 @@ class Doc {
*/
static Doc Text(std::string value);
/*!
+ * \brief Create a doc that represents raw text(can have new lines)
+ * \return The created doc.
+ */
+ static Doc RawText(std::string value);
+ /*!
* \brief Create a doc that represents a new line.
* \return The created doc.
*/
diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h
index 6c300fd..d390692 100644
--- a/src/printer/meta_data.h
+++ b/src/printer/meta_data.h
@@ -121,7 +121,7 @@ class TextMetaDataContext {
*/
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
- return Doc::Text(
+ return Doc::RawText(
SaveJSON(Map<std::string, ObjectRef>(meta_data_.begin(),
meta_data_.end())));
}
diff --git a/src/relay/analysis/alpha_equal.cc
b/src/relay/analysis/alpha_equal.cc
index 8a07a19..726ccbb 100644
--- a/src/relay/analysis/alpha_equal.cc
+++ b/src/relay/analysis/alpha_equal.cc
@@ -30,6 +30,8 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "../../ir/attr_functor.h"
+
+
namespace tvm {
namespace relay {
@@ -50,37 +52,7 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
- if (!lhs.defined() || !rhs.defined()) return false;
- if (lhs.same_as(rhs)) return true;
- if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
- if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return
false;
- return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
- }
- if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
- if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return
false;
- return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
- }
- if (const auto lhsm = lhs.as<IRModuleNode>()) {
- auto rhsm = rhs.as<IRModuleNode>();
- if (!rhsm) return false;
- if (lhsm->functions.size() != rhsm->functions.size()) return false;
- for (const auto& p : lhsm->functions) {
- if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
- }
- if (lhsm->type_definitions.size() != rhsm->type_definitions.size())
return false;
- for (const auto& p : lhsm->type_definitions) {
- if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
- !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
- return false;
- }
- }
- return true;
- }
- return AttrEqual(lhs, rhs);
- }
-
- bool DoubleEqual(double l, double r) {
- return true;
+ return VisitAttr(lhs, rhs);
}
/*!
* Check equality of two attributes.
@@ -90,25 +62,7 @@ class AlphaEqualHandler:
*/
bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
auto compute = [&]() {
- if (&lhs == &rhs) return true;
- if (auto lhsd = lhs.as<DictAttrsNode>()) {
- auto rhsd = rhs.as<DictAttrsNode>();
- if (!rhsd) return false;
- if (lhsd->dict.size() != rhsd->dict.size()) return false;
- for (const auto& k : lhsd->dict) {
- if (!Equal(k.second, rhsd->dict[k.first])) return false;
- }
- return true;
- }
- if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
- auto rhsbn = rhs.as<BatchNormAttrs>();
- if (!rhsbn) return false;
- return (lhsbn->axis == rhsbn->axis)
- && DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
- && (lhsbn->center == rhsbn->center)
- && (lhsbn->scale == rhsbn->scale);
- }
- return AttrsEqualHandler::Equal(lhs, rhs);
+ return VisitAttr(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
@@ -164,6 +118,40 @@ class AlphaEqualHandler:
}
protected:
+ // So that the new definition of equality in relay can be handled directly.
+ // Specifically, if a DictAttr contains a value defined by a relay AST.
+ // We want to able to recursively check the equality in the attr defined by
the relay AST.
+ bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
+ if (lhs.same_as(rhs)) return true;
+ if (!lhs.defined() && rhs.defined()) return false;
+ if (!rhs.defined() && lhs.defined()) return false;
+ if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
+ if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return
false;
+ return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
+ }
+ if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
+ if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return
false;
+ return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
+ }
+ if (const auto lhsm = lhs.as<IRModuleNode>()) {
+ auto rhsm = rhs.as<IRModuleNode>();
+ if (!rhsm) return false;
+ if (lhsm->functions.size() != rhsm->functions.size()) return false;
+ for (const auto& p : lhsm->functions) {
+ if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
+ }
+ if (lhsm->type_definitions.size() != rhsm->type_definitions.size())
return false;
+ for (const auto& p : lhsm->type_definitions) {
+ if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
+ !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ // Fall back to the object equal case.
+ return AttrsEqualHandler::VisitAttr(lhs, rhs);
+ }
/*!
* \brief Check if data type equals each other.
* \param lhs The left hand operand.
diff --git a/tests/python/unittest/test_ir_attrs.py
b/tests/python/unittest/test_ir_attrs.py
index a2be2b7..f4148ca 100644
--- a/tests/python/unittest/test_ir_attrs.py
+++ b/tests/python/unittest/test_ir_attrs.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import te
+import tvm.ir._ffi_api
def test_make_attrs():
try:
@@ -50,6 +50,19 @@ def test_dict_attrs():
assert len(dattr.items()) == 4
+def test_attrs_equal():
+ attr_equal = tvm.ir._ffi_api.AttrsEqual
+ dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20])
+ dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1)
+ dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None)
+ assert attr_equal(dattr0, dattr1)
+ assert not attr_equal(dattr0, dattr2)
+ assert not attr_equal({"x": 1}, tvm.runtime.convert(1))
+ assert not attr_equal([1, 2], tvm.runtime.convert(1))
+
+
+
if __name__ == "__main__":
test_make_attrs()
test_dict_attrs()
+ test_attrs_equal()