This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 997a14e [NODE][IR] Introduce StructuralEqual Infra for the unified
IR. (#5154)
997a14e is described below
commit 997a14eda9aec3b343e742e55c3018f9dc23d8c3
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Mar 27 22:21:00 2020 -0700
[NODE][IR] Introduce StructuralEqual Infra for the unified IR. (#5154)
* [NODE][IR] Introduce StructuralEqual Infra for the Unified IR.
This PR introduces a new way to handle structural equality
for both TIR and relay nodes in an extensible way.
- Each object can now register an optional SEqualReduce function, which
describes how to reduce its structural equality to another instance
into equality of the children.
- Optionally, the object can choose to allow remapping of vars (e.g.
function parameters) by calling DefEqual.
- We implemented a non-recursive structural equality checker that
traverses the objects using an internally managed stack (instead of
recursive calls) and performs the structural equality checking.
This PR also fixes a few potential problems in relay's previous AlphaEqual.
- In particular, the new structural equality relation is commutative
(symmetric): equal(a, b) holds if and only if equal(b, a) holds.
- It can be dangerous to use the same_as relation to quickly check equality,
as demonstrated by the following case, where (%x, %y) are vars shared
between the two functions:
- function0: fn (%x, %y) { %x + %y }
- function1: fn (%y, %x) { %x + %y }
The new structural equality is intended to supersede AlphaEqual and AttrsEqual.
Follow-up PRs should redirect the existing usages and remove
the corresponding implementations.
* Update the rule to distinguish between graph node and non-graph nodes.
* Refactor the test cases to use structural equal.
* address comments
* Mark more relay::Expr as graph nodes, fix a testcase issue (a bug that
was not caught by the previous AlphaEqual)
* Remove unrelated comment
* Fix file comment
* Address review comment
* Relax condition to fit flaky case
---
include/tvm/arith/analyzer.h | 8 +
include/tvm/arith/int_set.h | 1 +
include/tvm/ir/adt.h | 15 ++
include/tvm/ir/attrs.h | 41 +++
include/tvm/ir/env_func.h | 5 +
include/tvm/ir/expr.h | 21 ++
include/tvm/ir/module.h | 3 +
include/tvm/ir/op.h | 5 +
include/tvm/ir/span.h | 11 +
include/tvm/ir/tensor_type.h | 6 +
include/tvm/ir/transform.h | 2 +
include/tvm/ir/type.h | 43 ++++
include/tvm/ir/type_relation.h | 14 ++
include/tvm/node/container.h | 20 +-
include/tvm/node/node.h | 2 +
include/tvm/node/reflection.h | 148 +++++++++--
include/tvm/node/structural_equal.h | 225 +++++++++++++++++
include/tvm/relay/adt.h | 32 +++
include/tvm/relay/expr.h | 71 ++++++
include/tvm/relay/function.h | 11 +
include/tvm/runtime/ndarray.h | 22 ++
include/tvm/runtime/object.h | 4 +
include/tvm/tir/buffer.h | 15 ++
include/tvm/tir/expr.h | 153 ++++++++++-
include/tvm/tir/function.h | 11 +
include/tvm/tir/stmt.h | 103 ++++++++
python/tvm/ir/__init__.py | 1 +
python/tvm/ir/base.py | 73 ++++++
src/ir/attr_functor.h | 4 +-
src/ir/expr.cc | 8 +-
src/ir/module.cc | 19 +-
src/node/container.cc | 140 +++++++++++
src/node/reflection.cc | 2 +-
src/node/structural_equal.cc | 241 ++++++++++++++++++
src/tir/ir/expr.cc | 18 +-
tests/python/frontend/tensorflow/test_forward.py | 2 +-
tests/python/relay/test_ir_parser.py | 109 ++++----
..._alpha_equal.py => test_ir_structural_equal.py} | 280 +++++++++++----------
.../relay/test_pass_dead_code_elimination.py | 14 +-
tests/python/relay/test_pass_partial_eval.py | 26 +-
tests/python/relay/test_pass_qnn_legalize.py | 8 +-
tests/python/relay/test_pass_to_a_normal_form.py | 4 +-
tests/python/relay/test_pass_to_cps.py | 2 +-
tests/python/relay/test_type_infer.py | 3 +-
tests/python/unittest/test_node_reflection.py | 4 +-
tests/python/unittest/test_tir_structural_equal.py | 102 ++++++++
46 files changed, 1781 insertions(+), 271 deletions(-)
diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index 31f2216..e7f5ede 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -68,6 +68,10 @@ class ConstIntBoundNode : public Object {
v->Visit("max_value", &max_value);
}
+ bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const
{
+ return equal(min_value, other->min_value) && equal(max_value,
other->max_value);
+ }
+
/*! \brief Number to represent +inf */
static const constexpr int64_t kPosInf = std::numeric_limits<int64_t>::max();
/*!
@@ -170,6 +174,10 @@ class ModularSetNode : public Object {
v->Visit("base", &base);
}
+ bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const {
+ return equal(coeff, other->coeff) && equal(base, other->base);
+ }
+
static constexpr const char* _type_key = "arith.ModularSet";
TVM_DECLARE_FINAL_OBJECT_INFO(ModularSetNode, Object);
};
diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 8b73f87..86ef906 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -59,6 +59,7 @@ enum SignType {
class IntSetNode : public Object {
public:
static constexpr const char* _type_key = "IntSet";
+ static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h
index 67cfb8d..2601614 100644
--- a/include/tvm/ir/adt.h
+++ b/include/tvm/ir/adt.h
@@ -63,6 +63,14 @@ class ConstructorNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+ bool SEqualReduce(const ConstructorNode* other, SEqualReducer equal) const {
+ // Use namehint for now to be consistent with the legacy relay impl
+ // TODO(tvm-team) revisit, need to check the type var.
+ return
+ equal(name_hint, other->name_hint) &&
+ equal(inputs, other->inputs);
+ }
+
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstructorNode, RelayExprNode);
};
@@ -108,6 +116,13 @@ class TypeDataNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TypeDataNode* other, SEqualReducer equal) const {
+ return
+ equal.DefEqual(header, other->header) &&
+ equal.DefEqual(type_vars, other->type_vars) &&
+ equal(constructors, other->constructors);
+ }
+
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeDataNode, TypeNode);
};
diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 4413fc3..c3b5831 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -118,7 +118,9 @@ class AttrFieldInfoNode : public Object {
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
+
static constexpr const char* _type_key = "AttrFieldInfo";
+ static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
@@ -278,6 +280,7 @@ class BaseAttrsNode : public Object {
*/
TVM_DLL virtual size_t ContentHash(AttrsHash hasher) const = 0;
+ static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const char* _type_key = "Attrs";
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
@@ -302,6 +305,10 @@ class DictAttrsNode : public BaseAttrsNode {
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
+ bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
+ return equal(dict, other->dict);
+ }
+
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
@@ -401,6 +408,33 @@ class AttrsEqualVisitor {
const AttrsEqual& equal_;
};
+class AttrsSEqualVisitor {
+ public:
+ bool result_{true};
+ // constructor
+ AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const
SEqualReducer& equal)
+ : lhs_(lhs), rhs_(rhs), equal_(equal) {
+ }
+ template<typename T>
+ AttrNopEntry operator()(const char* key, T* lhs_value) {
+ if (!result_) return AttrNopEntry();
+ const T* rhs_value =
+ reinterpret_cast<const T*>(
+ reinterpret_cast<const char*>(rhs_) +
+ (reinterpret_cast<const char*>(lhs_value) -
+ reinterpret_cast<const char*>(lhs_)));
+ if (!equal_(*lhs_value, *rhs_value)) {
+ result_ = false;
+ }
+ return AttrNopEntry();
+ }
+
+ private:
+ const Object* lhs_;
+ const Object* rhs_;
+ const SEqualReducer& equal_;
+};
+
class AttrsHashVisitor {
public:
explicit AttrsHashVisitor(const AttrsHash& hasher)
@@ -817,6 +851,13 @@ class AttrsNode : public BaseAttrsNode {
}
}
+ bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
+ DerivedType* pself = self();
+ ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
+ self()->__VisitAttrs__(visitor);
+ return visitor.result_;
+ }
+
Array<AttrFieldInfo> ListFieldInfo() const final {
::tvm::detail::AttrDocVisitor visitor;
self()->__VisitAttrs__(visitor);
diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h
index f5b17bb..1064fd1 100644
--- a/include/tvm/ir/env_func.h
+++ b/include/tvm/ir/env_func.h
@@ -51,7 +51,12 @@ class EnvFuncNode : public Object {
v->Visit("name", &name);
}
+ bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
+ return this == other;
+ }
+
static constexpr const char* _type_key = "EnvFunc";
+ static constexpr bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object);
};
diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h
index 44244df..fc63da0 100644
--- a/include/tvm/ir/expr.h
+++ b/include/tvm/ir/expr.h
@@ -43,6 +43,7 @@ namespace tvm {
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};
@@ -197,6 +198,13 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+ bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const {
+ // name matters for global var.
+ return
+ equal(name_hint, other->name_hint) &&
+ equal.FreeVarEqualImpl(this, other);
+ }
+
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
@@ -228,6 +236,10 @@ class IntImmNode : public PrimExprNode {
v->Visit("value", &value);
}
+ bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
+ return equal(dtype, other->dtype) && equal(value, other->value);
+ }
+
static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
@@ -263,6 +275,10 @@ class FloatImmNode : public PrimExprNode {
v->Visit("value", &value);
}
+ bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const {
+ return equal(dtype, other->dtype) && equal(value, other->value);
+ }
+
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
@@ -353,7 +369,12 @@ class RangeNode : public Object {
v->Visit("extent", &extent);
}
+ bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const {
+ return equal(min, other->min) && equal(extent, other->extent);
+ }
+
static constexpr const char* _type_key = "Range";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object);
};
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 4613bec..38e583d 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -62,6 +62,8 @@ class IRModuleNode : public Object {
v->Visit("global_type_var_map_", &global_type_var_map_);
}
+ TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal)
const;
+
/*!
* \brief Add a function to the global environment.
* \param var The var of the global function.
@@ -235,6 +237,7 @@ class IRModuleNode : public Object {
TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "IRModule";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h
index 8a6ab77..f023e87 100644
--- a/include/tvm/ir/op.h
+++ b/include/tvm/ir/op.h
@@ -101,6 +101,11 @@ class OpNode : public RelayExprNode {
v->Visit("support_level", &support_level);
}
+ bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
+ // pointer equality is fine as there is only one op with the same name.
+ return this == other;
+ }
+
/*!
* \brief Check that if current op is a "primtive operator".
* That is the arguments are all type variables, and there is a single
diff --git a/include/tvm/ir/span.h b/include/tvm/ir/span.h
index 4720dfe..7194e90 100644
--- a/include/tvm/ir/span.h
+++ b/include/tvm/ir/span.h
@@ -44,6 +44,10 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
+ bool SEqualReduce(const SourceNameNode* other, SEqualReducer equal) const {
+ return equal(name, other->name);
+ }
+
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
@@ -87,6 +91,13 @@ class SpanNode : public Object {
v->Visit("col_offset", &col_offset);
}
+ bool SEqualReduce(const SpanNode* other, SEqualReducer equal) const {
+ return
+ equal(source, other->source) &&
+ equal(lineno, other->lineno) &&
+ equal(col_offset, other->col_offset);
+ }
+
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "Span";
diff --git a/include/tvm/ir/tensor_type.h b/include/tvm/ir/tensor_type.h
index 70a2df1..05c7a95 100644
--- a/include/tvm/ir/tensor_type.h
+++ b/include/tvm/ir/tensor_type.h
@@ -73,6 +73,12 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {
+ return
+ equal(shape, other->shape) &&
+ equal(dtype, other->dtype);
+ }
+
/*! \brief Return product of elements in the shape.
* \return (d1 * d_2 ... * d_n) if shape is (d_1, d_2, ..., d_n) and 1 if
shape size is zero.
*/
diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h
index 1b6ea25..ecd234a 100644
--- a/include/tvm/ir/transform.h
+++ b/include/tvm/ir/transform.h
@@ -111,6 +111,7 @@ class PassContextNode : public Object {
}
static constexpr const char* _type_key = "transform.PassContext";
+ static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
@@ -207,6 +208,7 @@ class PassInfoNode : public Object {
}
static constexpr const char* _type_key = "transform.PassInfo";
+ static constexpr bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index c23626e..dd70029 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -79,6 +79,7 @@ class TypeNode : public Object {
mutable Span span;
static constexpr const char* _type_key = "Type";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
@@ -110,6 +111,10 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}
+ bool SEqualReduce(const PrimTypeNode* other, SEqualReducer equal) const {
+ return equal(dtype, other->dtype);
+ }
+
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
@@ -152,6 +157,10 @@ class PointerTypeNode : public TypeNode {
v->Visit("element_type", &element_type);
}
+ bool SEqualReduce(const PointerTypeNode* other, SEqualReducer equal) const {
+ return equal(element_type, other->element_type);
+ }
+
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
@@ -218,6 +227,12 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TypeVarNode* other, SEqualReducer equal) const {
+ return
+ equal(kind, other->kind) &&
+ equal.FreeVarEqualImpl(this, other);
+ }
+
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
@@ -258,6 +273,13 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}
+ bool SEqualReduce(const GlobalTypeVarNode* other, SEqualReducer equal) const
{
+ // name matters for now in global type var.
+ return
+ equal(name_hint, other->name_hint) &&
+ equal.FreeVarEqualImpl(this, other);
+ }
+
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
@@ -294,6 +316,10 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TupleTypeNode* other, SEqualReducer equal) const {
+ return equal(fields, other->fields);
+ }
+
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
@@ -386,6 +412,15 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const FuncTypeNode* other, SEqualReducer equal) const {
+ // type params first as they defines type vars.
+ return
+ equal.DefEqual(type_params, other->type_params) &&
+ equal(arg_types, other->arg_types) &&
+ equal(ret_type, other->ret_type) &&
+ equal(type_constraints, other->type_constraints);
+ }
+
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
@@ -432,6 +467,10 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal)
const {
+ return equal(kind, other->kind);
+ }
+
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
@@ -469,6 +508,10 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const RelayRefTypeNode* other, SEqualReducer equal) const {
+ return equal(value, other->value);
+ }
+
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
diff --git a/include/tvm/ir/type_relation.h b/include/tvm/ir/type_relation.h
index f7bfb68..592bf25 100644
--- a/include/tvm/ir/type_relation.h
+++ b/include/tvm/ir/type_relation.h
@@ -50,6 +50,12 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TypeCallNode* other, SEqualReducer equal) const {
+ return
+ equal(func, other->func) &&
+ equal(args, other->args);
+ }
+
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
@@ -195,6 +201,14 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
+ bool SEqualReduce(const TypeRelationNode* other, SEqualReducer equal) const {
+ return
+ equal(func, other->func) &&
+ equal(args, other->args) &&
+ equal(num_inputs, other->num_inputs) &&
+ equal(attrs, other->attrs);
+ }
+
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h
index a541385..461fa11 100644
--- a/include/tvm/node/container.h
+++ b/include/tvm/node/container.h
@@ -23,7 +23,9 @@
#ifndef TVM_NODE_CONTAINER_H_
#define TVM_NODE_CONTAINER_H_
-#include <tvm/node/node.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/packed_func.h>
#include <type_traits>
#include <vector>
@@ -34,15 +36,19 @@
namespace tvm {
+using runtime::Object;
+using runtime::ObjectPtr;
+using runtime::ObjectRef;
+using runtime::make_object;
+using runtime::ObjectHash;
+using runtime::ObjectEqual;
+
/*! \brief array node content in array */
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;
- void VisitAttrs(AttrVisitor* visitor) {
- }
-
static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};
@@ -50,9 +56,6 @@ class ArrayNode : public Object {
/*! \brief map node content */
class MapNode : public Object {
public:
- void VisitAttrs(AttrVisitor* visitor) {
- }
-
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<
ObjectRef,
@@ -73,9 +76,6 @@ class StrMapNode : public Object {
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;
- void VisitAttrs(AttrVisitor* visitor) {
- }
-
/*! \brief the data content */
ContainerType data;
diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index 3ea3d76..76e574b 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -39,6 +39,8 @@
#include <tvm/runtime/memory.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
+#include <tvm/node/container.h>
+#include <tvm/node/structural_equal.h>
#include <string>
#include <vector>
diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h
index daffeb8..d0a9304 100644
--- a/include/tvm/node/reflection.h
+++ b/include/tvm/node/reflection.h
@@ -29,13 +29,14 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
+#include <tvm/node/structural_equal.h>
#include <vector>
#include <string>
+#include <type_traits>
namespace tvm {
-// forward declaration
using runtime::Object;
using runtime::ObjectPtr;
using runtime::ObjectRef;
@@ -87,6 +88,13 @@ class ReflectionVTable {
*/
typedef void (*FVisitAttrs)(Object* self, AttrVisitor* visitor);
/*!
+ * \brief Equality comparison function.
+ * \note We use function pointer, instead of std::function
+ * to reduce the dispatch overhead as field visit
+ * does not need as much customization.
+ */
+ typedef bool (*FSEqualReduce)(const Object* self, const Object* other,
SEqualReducer equal);
+ /*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the
object.
@@ -112,6 +120,14 @@ class ReflectionVTable {
*/
inline std::string GetGlobalKey(Object* self) const;
/*!
+ * \brief Dispatch the SEqualReduce function.
+ * \param self The pointer to the object.
+ * \param other The pointer to another object to be compared.
+ * \param equal The equality comparator.
+ * \return the result.
+ */
+ bool SEqualReduce(const Object* self, const Object* other, SEqualReducer
equal) const;
+ /*!
* \brief Create an initial object using default constructor
* by type_key and global key.
*
@@ -139,12 +155,14 @@ class ReflectionVTable {
TVM_DLL static ReflectionVTable* Global();
class Registry;
- template<typename T>
+ template<typename T, typename TraitName>
inline Registry Register();
private:
/*! \brief Attribute visitor. */
std::vector<FVisitAttrs> fvisit_attrs_;
+ /*! \brief Structural equal function. */
+ std::vector<FSEqualReduce> fsequal_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
@@ -182,6 +200,44 @@ class ReflectionVTable::Registry {
uint32_t type_index_;
};
+
+// Declaration prefix for a registry variable; TVM_REGISTER_REFLECTION_VTABLE appends __COUNTER__ to make it unique.
+#define TVM_REFLECTION_REG_VAR_DEF \
+  static TVM_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry \
+  __make_reflection
+
+/*!
+ * \brief Directly register reflection VTable.
+ * \param TypeName The name of the type.
+ * \param TraitName A trait class that implements functions like VisitAttrs and SEqualReduce.
+ *
+ * \code
+ *
+ * // Example SEqualReduce traits for runtime StringObj.
+ *
+ * struct StringObjTrait {
+ *   static constexpr const std::nullptr_t VisitAttrs = nullptr;
+ *
+ *   static bool SEqualReduce(const runtime::StringObj* lhs,
+ *                            const runtime::StringObj* rhs,
+ *                            SEqualReducer equal) {
+ *     if (lhs == rhs) return true;
+ *     if (lhs->size != rhs->size) return false;
+ *     // Same buffer (and same size): trivially equal.
+ *     if (lhs->data == rhs->data) return true;
+ *     // Different buffers: equal only if the bytes match.
+ *     return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+ *   }
+ * };
+ *
+ * TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+ *
+ * \endcode
+ *
+ * \note This macro can be invoked in a different place from TVM_REGISTER_OBJECT_TYPE,
+ *       and can be used to register the related reflection functions for runtime objects.
+ */
+#define TVM_REGISTER_REFLECTION_VTABLE(TypeName, TraitName) \
+ TVM_STR_CONCAT(TVM_REFLECTION_REG_VAR_DEF, __COUNTER__) = \
+ ::tvm::ReflectionVTable::Global()->Register<TypeName, TraitName>() \
+
/*!
* \brief Register a node type to object registry and reflection registry.
* \param TypeName The name of the type.
@@ -189,15 +245,79 @@ class ReflectionVTable::Registry {
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \
TVM_REGISTER_OBJECT_TYPE(TypeName); \
- static DMLC_ATTRIBUTE_UNUSED ::tvm::ReflectionVTable::Registry & \
- __make_Node ## _ ## TypeName ## __ = \
- ::tvm::ReflectionVTable::Global()->Register<TypeName>() \
- .set_creator([](const std::string&) -> ObjectPtr<Object> { \
- return ::tvm::runtime::make_object<TypeName>(); \
- })
+ TVM_REGISTER_REFLECTION_VTABLE(TypeName,
::tvm::detail::ReflectionTrait<TypeName>) \
+ .set_creator([](const std::string&) -> ObjectPtr<Object> { \
+ return ::tvm::runtime::make_object<TypeName>(); \
+ })
+
// Implementation details
+namespace detail {
+
+template<typename T,
+ bool = T::_type_has_method_visit_attrs>
+struct ImplVisitAttrs {
+ static constexpr const std::nullptr_t VisitAttrs = nullptr;
+};
+
+template<typename T>
+struct ImplVisitAttrs<T, true> {
+ static void VisitAttrs(T* self, AttrVisitor* v) {
+ self->VisitAttrs(v);
+ }
+};
+
+template<typename T,
+ bool = T::_type_has_method_sequal_reduce>
+struct ImplSEqualReduce {
+ static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+template<typename T>
+struct ImplSEqualReduce<T, true> {
+ static bool SEqualReduce(const T* self, const T* other, SEqualReducer equal)
{
+ return self->SEqualReduce(other, equal);
+ }
+};
+
template<typename T>
+struct ReflectionTrait :
+ public ImplVisitAttrs<T>,
+ public ImplSEqualReduce<T> {
+};
+
+template<typename T, typename TraitName,
+ bool = std::is_null_pointer<decltype(TraitName::VisitAttrs)>::value>
+struct SelectVisitAttrs {
+ static constexpr const std::nullptr_t VisitAttrs = nullptr;
+};
+
+template<typename T, typename TraitName>
+struct SelectVisitAttrs<T, TraitName, false> {
+ static void VisitAttrs(Object* self, AttrVisitor* v) {
+ TraitName::VisitAttrs(static_cast<T*>(self), v);
+ }
+};
+
+template<typename T, typename TraitName,
+ bool = std::is_null_pointer<decltype(TraitName::SEqualReduce)>::value>
+struct SelectSEqualReduce {
+ static constexpr const std::nullptr_t SEqualReduce = nullptr;
+};
+
+template<typename T, typename TraitName>
+struct SelectSEqualReduce<T, TraitName, false> {
+ static bool SEqualReduce(const Object* self,
+ const Object* other,
+ SEqualReducer equal) {
+ return TraitName::SEqualReduce(static_cast<const T*>(self),
+ static_cast<const T*>(other),
+ equal);
+ }
+};
+} // namespace detail
+
+template<typename T, typename TraitName>
inline ReflectionVTable::Registry
ReflectionVTable::Register() {
uint32_t tindex = T::RuntimeTypeIndex();
@@ -205,15 +325,15 @@ ReflectionVTable::Register() {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
+ fsequal_.resize(tindex + 1, nullptr);
}
// functor that implemnts the redirection.
- struct Functor {
- static void VisitAttrs(Object* self, AttrVisitor* v) {
- static_cast<T*>(self)->VisitAttrs(v);
- }
- };
+ fvisit_attrs_[tindex] =
+ ::tvm::detail::SelectVisitAttrs<T, TraitName>::VisitAttrs;
+
+ fsequal_[tindex] =
+ ::tvm::detail::SelectSEqualReduce<T, TraitName>::SEqualReduce;
- fvisit_attrs_[tindex] = Functor::VisitAttrs;
return Registry(this, tindex);
}
diff --git a/include/tvm/node/structural_equal.h
b/include/tvm/node/structural_equal.h
new file mode 100644
index 0000000..f719e24
--- /dev/null
+++ b/include/tvm/node/structural_equal.h
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/node/structural_equal.h
+ * \brief Structural equality comparison.
+ */
+#ifndef TVM_NODE_STRUCTURAL_EQUAL_H_
+#define TVM_NODE_STRUCTURAL_EQUAL_H_
+
+#include <tvm/runtime/data_type.h>
+#include <tvm/node/functor.h>
+#include <tvm/node/container.h>
+#include <string>
+
+namespace tvm {
+
+/*!
+ * \brief Equality definition of base value class.
+ */
+class BaseValueEqual {
+ public:
+ bool operator()(const double& lhs, const double& rhs) const {
+ // fuzzy float pt comparison
+ constexpr double atol = 1e-9;
+ if (lhs == rhs) return true;
+ double diff = lhs - rhs;
+ return diff > -atol && diff < atol;
+ }
+
+ bool operator()(const int64_t& lhs, const int64_t& rhs) const {
+ return lhs == rhs;
+ }
+ bool operator()(const uint64_t& lhs, const uint64_t& rhs) const {
+ return lhs == rhs;
+ }
+ bool operator()(const int& lhs, const int& rhs) const {
+ return lhs == rhs;
+ }
+ bool operator()(const bool& lhs, const bool& rhs) const {
+ return lhs == rhs;
+ }
+ bool operator()(const std::string& lhs, const std::string& rhs) const {
+ return lhs == rhs;
+ }
+ bool operator()(const DataType& lhs, const DataType& rhs) const {
+ return lhs == rhs;
+ }
+ template<typename ENum,
+ typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
+ bool operator()(const ENum& lhs, const ENum& rhs) const {
+ return lhs == rhs;
+ }
+};
+
+/*!
+ * \brief Content-aware structural equality comparator for objects.
+ *
+ * The structural equality is recursively defined in the DAG of IR nodes via SEqualReduce.
+ * There are two kinds of nodes:
+ *
+ * - Graph node: a graph node in lhs can only be mapped as equal to
+ *   one and only one graph node in rhs.
+ * - Normal node: equality is recursively defined without the restriction
+ *   of graph nodes.
+ *
+ * Vars (tir::Var, TypeVar) and non-constant relay expression nodes are graph nodes.
+ * For example, it means that `%1 = %x + %y; %1 + %1` is not structurally equal
+ * to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
+ *
+ * A var-type node (e.g. tir::Var, TypeVar) can be mapped as equal to another var
+ * with the same type if one of the following conditions holds:
+ *
+ * - They appear at the same definition point (e.g. function argument).
+ * - They point to the same VarNode via the same_as relation.
+ * - They appear at the same usage point, and map_free_vars is set to true.
+ */
+class StructuralEqual : public BaseValueEqual {
+ public:
+  // inherit operator() for base values (numbers, strings, DataType, enums)
+  using BaseValueEqual::operator();
+  /*!
+   * \brief Compare objects via structural equality.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return The comparison result.
+   */
+  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
+};
+
+/*!
+ * \brief A Reducer class to reduce the structural equality result of two objects.
+ *
+ * The reducer will call the SEqualReduce function of each object recursively.
+ * Importantly, the reducer may not directly use recursive calls to resolve the
+ * equality checking. Instead, it can store the necessary equality conditions
+ * and check later via an internally managed stack.
+ */
+class SEqualReducer : public BaseValueEqual {
+ public:
+  /*! \brief Internal handler that defines custom behaviors. */
+  class Handler {
+   public:
+    /*!
+     * \brief Reduce condition to equality of lhs and rhs.
+     *
+     * \param lhs The left operand.
+     * \param rhs The right operand.
+     * \param map_free_vars Whether or not to allow remapping of free variables if possible.
+     *
+     * \return false if there is an immediate failure, true otherwise.
+     * \note This function may save the equality condition of (lhs == rhs) in an internal
+     *       stack and try to resolve later.
+     */
+    virtual bool SEqualReduce(const ObjectRef& lhs,
+                              const ObjectRef& rhs,
+                              bool map_free_vars) = 0;
+    /*!
+     * \brief Lookup the graph node equal map for vars that are already mapped.
+     *
+     * This is an auxiliary method to check the Map<Var, Value> equality.
+     * \param lhs an lhs value.
+     *
+     * \return The corresponding rhs value if any, an undefined (null) ObjectRef otherwise.
+     */
+    virtual ObjectRef MapLhsToRhs(const ObjectRef& lhs) = 0;
+    /*!
+     * \brief Mark current comparison as graph node equal comparison.
+     */
+    virtual void MarkGraphNode() = 0;
+  };
+
+  using BaseValueEqual::operator();
+
+  /*! \brief Default constructor; leaves the reducer without a handler (nullptr). */
+  SEqualReducer() = default;
+  /*!
+   * \brief Constructor with a specific handler.
+   * \param handler The equal handler for objects.
+   * \param map_free_vars Whether or not to map free variables.
+   */
+  explicit SEqualReducer(Handler* handler, bool map_free_vars)
+      : handler_(handler), map_free_vars_(map_free_vars) {}
+  /*!
+   * \brief Reduce condition to comparison of two objects.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
+    return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
+  }
+  /*!
+   * \brief Reduce condition to comparison of two definitions,
+   *        where free vars can be mapped.
+   *
+   * Call this function to compare definition points such as function params
+   * and var in a let-binding.
+   *
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) const {
+    return handler_->SEqualReduce(lhs, rhs, true);
+  }
+  /*!
+   * \brief Reduce condition to comparison of two arrays.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the immediate check result.
+   */
+  template<typename T>
+  bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
+    // quick specialization for Array to reduce amount of recursion
+    // depth as array comparison is pretty common.
+    if (lhs.size() != rhs.size()) return false;
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!(operator()(lhs[i], rhs[i]))) return false;
+    }
+    return true;
+  }
+  /*!
+   * \brief Implementation for equality rule of var type objects (e.g. TypeVar, tir::Var).
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \return the result.
+   */
+  bool FreeVarEqualImpl(const runtime::Object* lhs, const runtime::Object* rhs) const {
+    // vars need to be remapped, so they are graph nodes.
+    handler_->MarkGraphNode();
+    // We only map free vars if they correspond to the same address
+    // or the map_free_vars option is set to true.
+    return lhs == rhs || map_free_vars_;
+  }
+
+  /*! \return Get the internal handler. */
+  Handler* operator->() const {
+    return handler_;
+  }
+
+ private:
+  /*! \brief Internal class pointer; nullptr for a default-constructed reducer. */
+  Handler* handler_{nullptr};
+  /*! \brief Whether or not to map free vars. */
+  bool map_free_vars_{false};
+};
+
+} // namespace tvm
+#endif // TVM_NODE_STRUCTURAL_EQUAL_H_
diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h
index 8189b21..ea13e25 100644
--- a/include/tvm/relay/adt.h
+++ b/include/tvm/relay/adt.h
@@ -46,6 +46,7 @@ using TypeDataNode = tvm::TypeDataNode;
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(PatternNode, Object);
};
@@ -74,6 +75,10 @@ class PatternWildcardNode : public PatternNode {
v->Visit("span", &span);
}
+  // A wildcard pattern has nothing to compare, so any two are equal.
+  // NOTE(review): this overload takes `const PatternNode*` while sibling
+  // pattern nodes take their own node type; `other` is unused either way.
+  bool SEqualReduce(const PatternNode* other, SEqualReducer equal) const {
+    return true;
+  }
+
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternWildcardNode, PatternNode);
};
@@ -118,6 +123,10 @@ class PatternVarNode : public PatternNode {
v->Visit("span", &span);
}
+  // The var bound by the pattern is a definition point, so it is compared
+  // with DefEqual and may be mapped to the other side's var.
+  bool SEqualReduce(const PatternVarNode* other, SEqualReducer equal) const {
+    const bool same_def = equal.DefEqual(var, other->var);
+    return same_def;
+  }
+
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternVarNode, PatternNode);
};
@@ -149,6 +158,12 @@ class PatternConstructorNode : public PatternNode {
v->Visit("span", &span);
}
+  // Equal when the constructor matches and all sub-patterns match in order.
+  bool SEqualReduce(const PatternConstructorNode* other, SEqualReducer equal) const {
+    if (!equal(constructor, other->constructor)) return false;
+    return equal(patterns, other->patterns);
+  }
+
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternConstructorNode, PatternNode);
};
@@ -178,6 +193,10 @@ class PatternTupleNode : public PatternNode {
v->Visit("span", &span);
}
+  // Tuple patterns are equal when their sub-patterns are pairwise equal.
+  bool SEqualReduce(const PatternTupleNode* other, SEqualReducer equal) const {
+    const bool same_patterns = equal(patterns, other->patterns);
+    return same_patterns;
+  }
+
static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternTupleNode, PatternNode);
};
@@ -208,7 +227,12 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs);
}
+  // The pattern (lhs) is compared before the clause body (rhs).
+  bool SEqualReduce(const ClauseNode* other, SEqualReducer equal) const {
+    if (!equal(lhs, other->lhs)) return false;
+    return equal(rhs, other->rhs);
+  }
+
static constexpr const char* _type_key = "relay.Clause";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
@@ -248,6 +272,14 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const MatchNode* other, SEqualReducer equal) const {
+    // Match expressions take part in the graph-node mapping.
+    equal->MarkGraphNode();
+    if (!equal(data, other->data)) return false;
+    if (!equal(clauses, other->clauses)) return false;
+    return equal(complete, other->complete);
+  }
+
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 3acb5dd..731046e 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -26,6 +26,7 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
+#include <tvm/ir/op.h>
#include <tvm/ir/module.h>
#include <string>
#include <functional>
@@ -72,6 +73,10 @@ class ConstantNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  // Constants compare by their data only and are not marked as graph nodes.
+  bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
+    const bool same_data = equal(data, other->data);
+    return same_data;
+  }
+
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
@@ -101,6 +106,16 @@ class TupleNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
+    // Two empty tuples act like a constant: equal, and not a graph node.
+    if (fields.size() == 0 && other->fields.size() == 0) {
+      return true;
+    }
+    // Any other tuple comparison participates in the graph-node mapping.
+    equal->MarkGraphNode();
+    return equal(fields, other->fields);
+  }
+
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
@@ -157,6 +172,12 @@ class VarNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    // Annotations must agree before the free-var identity rule is applied.
+    if (!equal(type_annotation, other->type_annotation)) return false;
+    return equal.FreeVarEqualImpl(this, other);
+  }
+
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
@@ -238,6 +259,16 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    if (!equal(op, other->op)) return false;
+    if (!equal(args, other->args)) return false;
+    if (!equal(attrs, other->attrs)) return false;
+    // type_args are not compared when the callee is a primitive op.
+    return IsPrimitiveOp(op) || equal(type_args, other->type_args);
+  }
+
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
@@ -289,6 +320,14 @@ class LetNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    // The bound var is a definition point and is compared first.
+    if (!equal.DefEqual(var, other->var)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(body, other->body);
+  }
+
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
@@ -336,6 +375,14 @@ class IfNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const IfNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    if (!equal(cond, other->cond)) return false;
+    if (!equal(true_branch, other->true_branch)) return false;
+    return equal(false_branch, other->false_branch);
+  }
+
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
@@ -369,6 +416,12 @@ class TupleGetItemNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  // NOTE(review): unlike most relay expression nodes here, this does not call
+  // MarkGraphNode(); confirm that is intended.
+  bool SEqualReduce(const TupleGetItemNode* other, SEqualReducer equal) const {
+    if (!equal(tuple, other->tuple)) return false;
+    return equal(index, other->index);
+  }
+
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
@@ -398,6 +451,11 @@ class RefCreateNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const RefCreateNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    const bool same_value = equal(value, other->value);
+    return same_value;
+  }
+
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
@@ -426,6 +484,11 @@ class RefReadNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const RefReadNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    const bool same_ref = equal(ref, other->ref);
+    return same_ref;
+  }
+
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
@@ -456,6 +519,13 @@ class RefWriteNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const RefWriteNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    if (!equal(ref, other->ref)) return false;
+    return equal(value, other->value);
+  }
+
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
@@ -497,6 +567,7 @@ class TempExprNode : public ExprNode {
virtual Expr Realize() const = 0;
static constexpr const char* _type_key = "relay.TempExpr";
+ static constexpr const bool _type_has_method_sequal_reduce = false;
TVM_DECLARE_BASE_OBJECT_INFO(TempExprNode, ExprNode);
};
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 5c5bd26..ed39caa 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -68,6 +68,17 @@ class FunctionNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
+    equal->MarkGraphNode();
+    // Definition points (params, type_params) go first so they are mapped
+    // before ret_type, attrs and body are compared.
+    if (!equal.DefEqual(params, other->params)) return false;
+    if (!equal.DefEqual(type_params, other->type_params)) return false;
+    if (!equal(ret_type, other->ret_type)) return false;
+    if (!equal(attrs, other->attrs)) return false;
+    return equal(body, other->body);
+  }
+
/*!
* \brief Return the derived function annotation of this expression.
*
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 2441ab6..17f81a2 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -65,6 +65,8 @@ class NDArray : public ObjectRef {
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
+ /*! \return Whether the tensor is contiguous */
+ inline bool IsContiguous() const;
/*!
* \brief Copy data content from another array.
* \param other The source array to be copied from.
@@ -313,6 +315,26 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size;
}
+/*!
+ * \brief check if a DLTensor is contiguous.
+ *
+ * A tensor without explicit strides (strides == nullptr) is reported as
+ * contiguous. Otherwise, walking from the innermost axis outward, each
+ * stride must equal the product of the extents of all inner axes.
+ * Axes of extent 1 are skipped: only index 0 is ever used along such an
+ * axis, so its stride never contributes to an address. DLPack producers
+ * may report any value there.
+ *
+ * \param arr The input DLTensor.
+ * \return The check result.
+ */
+inline bool IsContiguous(const DLTensor& arr) {
+  if (arr.strides == nullptr) return true;
+  int64_t expected_stride = 1;
+  for (int32_t i = arr.ndim; i != 0; --i) {
+    int32_t k = i - 1;
+    // The stride of an extent-1 axis does not affect the memory layout.
+    if (arr.shape[k] == 1) continue;
+    if (arr.strides[k] != expected_stride) return false;
+    expected_stride *= arr.shape[k];
+  }
+  return true;
+}
+
+/*! \return Whether the DLTensor held by this NDArray is contiguous. */
+inline bool NDArray::IsContiguous() const {
+  const DLTensor& tensor = get_mutable()->dl_tensor;
+  return ::tvm::runtime::IsContiguous(tensor);
+}
+
inline void NDArray::CopyFrom(const DLTensor* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(get_mutable()->dl_tensor));
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index fe5e30b..80b479d 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -211,11 +211,15 @@ class Object {
static constexpr bool _type_final = false;
static constexpr uint32_t _type_child_slots = 0;
static constexpr bool _type_child_slots_can_overflow = true;
+ // member information
+ static constexpr bool _type_has_method_visit_attrs = true;
+ static constexpr bool _type_has_method_sequal_reduce = false;
// NOTE: the following field is not type index of Object
// but was intended to be used by sub-classes as default value.
// The type index of Object is TypeIndex::kRoot
static constexpr uint32_t _type_index = TypeIndex::kDynamic;
+
// Default constructor and copy constructor
Object() {}
// Override the copy and assign constructors to do nothing.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index c172316..60dd455 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -150,6 +150,20 @@ class BufferNode : public Object {
v->Visit("buffer_type", &buffer_type);
}
+  bool SEqualReduce(const BufferNode* other, SEqualReducer equal) const {
+    // A buffer can define variables, so data, shape, strides and elem_offset
+    // are compared with DefEqual. The name is deliberately not compared.
+    if (!equal.DefEqual(data, other->data)) return false;
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal.DefEqual(shape, other->shape)) return false;
+    if (!equal.DefEqual(strides, other->strides)) return false;
+    if (!equal.DefEqual(elem_offset, other->elem_offset)) return false;
+    if (!equal(scope, other->scope)) return false;
+    if (!equal(data_alignment, other->data_alignment)) return false;
+    return equal(buffer_type, other->buffer_type);
+  }
+
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
@@ -169,6 +183,7 @@ class BufferNode : public Object {
BufferType buffer_type);
static constexpr const char* _type_key = "Buffer";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
};
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 90fef87..28e6186 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -75,6 +75,12 @@ class VarNode : public PrimExprNode {
v->Visit("type_annotation", &type_annotation);
}
+  // dtype and type annotation must match before the free-var identity rule applies.
+  bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
+    return equal(dtype, other->dtype) &&
+           equal(type_annotation, other->type_annotation) &&
+           equal.FreeVarEqualImpl(this, other);
+  }
+
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
@@ -288,11 +294,20 @@ class IterVarNode : public Object {
v->Visit("thread_tag", &thread_tag);
}
+  bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
+    if (!equal(dom, other->dom)) return false;
+    // The iteration var is a definition point.
+    if (!equal.DefEqual(var, other->var)) return false;
+    if (!equal(iter_type, other->iter_type)) return false;
+    return equal(thread_tag, other->thread_tag);
+  }
+
TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
};
@@ -334,6 +349,10 @@ class StringImmNode : public PrimExprNode {
v->Visit("value", &value);
}
+  // String immediates are equal when their string values are equal.
+  bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const {
+    const bool same_value = equal(value, other->value);
+    return same_value;
+  }
+
TVM_DLL PrimExpr static make(std::string value);
static constexpr const char* _type_key = "StringImm";
@@ -359,6 +378,10 @@ class CastNode : public PrimExprNode {
v->Visit("value", &value);
}
+  // Target dtype first, then the expression being cast.
+  bool SEqualReduce(const CastNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    return equal(value, other->value);
+  }
+
TVM_DLL static PrimExpr make(DataType t, PrimExpr v);
static constexpr const char* _type_key = "Cast";
@@ -383,6 +406,13 @@ class BinaryOpNode : public PrimExprNode {
v->Visit("b", &b);
}
+  // Shared by all binary ops: dtype, then operands a and b in order.
+  bool SEqualReduce(const T* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(a, other->a)) return false;
+    return equal(b, other->b);
+  }
+
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
@@ -475,6 +505,13 @@ class CmpOpNode : public PrimExprNode {
v->Visit("b", &b);
}
+  // Shared by all comparison ops: dtype, then operands a and b in order.
+  bool SEqualReduce(const T* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(a, other->a)) return false;
+    return equal(b, other->b);
+  }
+
static PrimExpr make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined\n";
CHECK(b.defined()) << "ValueError: b is undefined\n";
@@ -539,6 +576,13 @@ class AndNode : public PrimExprNode {
v->Visit("b", &b);
}
+  bool SEqualReduce(const AndNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(a, other->a)) return false;
+    return equal(b, other->b);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "And";
@@ -559,6 +603,13 @@ class OrNode : public PrimExprNode {
v->Visit("b", &b);
}
+  bool SEqualReduce(const OrNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(a, other->a)) return false;
+    return equal(b, other->b);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr a, PrimExpr b);
static constexpr const char* _type_key = "Or";
@@ -576,6 +627,10 @@ class NotNode : public PrimExprNode {
v->Visit("a", &a);
}
+  bool SEqualReduce(const NotNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    return equal(a, other->a);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr a);
static constexpr const char* _type_key = "Not";
@@ -605,6 +660,14 @@ class SelectNode : public PrimExprNode {
v->Visit("false_value", &false_value);
}
+  bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(condition, other->condition)) return false;
+    if (!equal(true_value, other->true_value)) return false;
+    return equal(false_value, other->false_value);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr condition, PrimExpr true_value,
PrimExpr false_value);
static constexpr const char* _type_key = "Select";
@@ -642,6 +705,14 @@ class LoadNode : public PrimExprNode {
v->Visit("predicate", &predicate);
}
+  bool SEqualReduce(const LoadNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    // buffer_var is a use here, not a definition.
+    if (!equal(buffer_var, other->buffer_var)) return false;
+    if (!equal(index, other->index)) return false;
+    return equal(predicate, other->predicate);
+  }
+
TVM_DLL static PrimExpr make(DataType dtype, Var buffer_var, PrimExpr index,
PrimExpr predicate);
static constexpr const char* _type_key = "Load";
@@ -673,6 +744,14 @@ class RampNode : public PrimExprNode {
v->Visit("lanes", &lanes);
}
+  bool SEqualReduce(const RampNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(base, other->base)) return false;
+    if (!equal(stride, other->stride)) return false;
+    return equal(lanes, other->lanes);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr base, PrimExpr stride, int lanes);
static constexpr const char* _type_key = "Ramp";
@@ -693,6 +772,13 @@ class BroadcastNode : public PrimExprNode {
v->Visit("lanes", &lanes);
}
+  bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(lanes, other->lanes);
+  }
+
TVM_DLL static PrimExpr make(PrimExpr value, int lanes);
static constexpr const char* _type_key = "Broadcast";
@@ -718,6 +804,14 @@ class LetNode : public PrimExprNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const LetNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    // The let-bound var is a definition point.
+    if (!equal.DefEqual(var, other->var)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static PrimExpr make(Var var, PrimExpr value, PrimExpr body);
static constexpr const char* _type_key = "Let";
@@ -788,12 +882,22 @@ class CallNode : public PrimExprNode {
v->Visit("value_index", &value_index);
}
+  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(name, other->name)) return false;
+    if (!equal(args, other->args)) return false;
+    if (!equal(call_type, other->call_type)) return false;
+    if (!equal(func, other->func)) return false;
+    return equal(value_index, other->value_index);
+  }
+
TVM_DLL static PrimExpr make(DataType dtype,
- std::string name,
- Array<PrimExpr> args,
- CallType call_type,
- FunctionRef func = FunctionRef(),
- int value_index = 0);
+ std::string name,
+ Array<PrimExpr> args,
+ CallType call_type,
+ FunctionRef func = FunctionRef(),
+ int value_index = 0);
/*! \return Whether call node is pure. */
bool is_pure() const {
@@ -856,6 +960,13 @@ class ShuffleNode : public PrimExprNode {
v->Visit("indices", &indices);
}
+  bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(vectors, other->vectors)) return false;
+    return equal(indices, other->indices);
+  }
+
TVM_DLL static PrimExpr make(Array<PrimExpr> vectors, Array<PrimExpr>
indices);
TVM_DLL static PrimExpr make_concat(Array<PrimExpr> vectors);
TVM_DLL static PrimExpr make_extract_element(PrimExpr vector, int index);
@@ -918,7 +1029,16 @@ class CommReducerNode : public Object {
v->Visit("identity_element", &identity_element);
}
+  bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const {
+    // lhs/rhs are the reducer's parameter definitions; map them first.
+    if (!equal.DefEqual(lhs, other->lhs)) return false;
+    if (!equal.DefEqual(rhs, other->rhs)) return false;
+    if (!equal(result, other->result)) return false;
+    return equal(identity_element, other->identity_element);
+  }
+
static constexpr const char* _type_key = "CommReducer";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
};
@@ -948,10 +1068,10 @@ class ReduceNode : public PrimExprNode {
/*! \brief construct expr from op and rdom */
TVM_DLL static PrimExpr make(CommReducer combiner,
- Array<PrimExpr> src,
- Array<IterVar> rdom,
- PrimExpr condition,
- int value_index);
+ Array<PrimExpr> src,
+ Array<IterVar> rdom,
+ PrimExpr condition,
+ int value_index);
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
@@ -962,6 +1082,16 @@ class ReduceNode : public PrimExprNode {
v->Visit("value_index", &value_index);
}
+  bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const {
+    if (!equal(dtype, other->dtype)) return false;
+    // axis comes before the remaining fields so its IterVars can define
+    // the necessary variables.
+    if (!equal(axis, other->axis)) return false;
+    if (!equal(combiner, other->combiner)) return false;
+    if (!equal(source, other->source)) return false;
+    if (!equal(condition, other->condition)) return false;
+    return equal(value_index, other->value_index);
+  }
static constexpr const char* _type_key = "Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
@@ -970,6 +1100,11 @@ class ReduceNode : public PrimExprNode {
class AnyNode : public PrimExprNode {
public:
void VisitAttrs(AttrVisitor* v) {}
+
+  // AnyNode has no fields to compare, so any two instances are equal.
+  bool SEqualReduce(const AnyNode* other, SEqualReducer equal) const {
+    return true;
+  }
+
/*! \brief Convert to var. */
Var ToVar() const {
return Var("any_dim", DataType::Int(32));
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 63a8630..26b643a 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -102,6 +102,16 @@ class PrimFuncNode : public BaseFuncNode {
v->Visit("_checked_type_", &checked_type_);
}
+  bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
+    // params and buffer_map contain definitions, so they are visited first.
+    if (!equal.DefEqual(params, other->params)) return false;
+    if (!equal(buffer_map, other->buffer_map)) return false;
+    if (!equal(ret_type, other->ret_type)) return false;
+    if (!equal(body, other->body)) return false;
+    return equal(attrs, other->attrs);
+  }
+
/*!
* \brief Return the derived function annotation of this function.
*
@@ -112,6 +122,7 @@ class PrimFuncNode : public BaseFuncNode {
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "tir.PrimFunc";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode);
};
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index a543737..d4b144d 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -38,6 +38,7 @@ namespace tir {
class StmtNode : public Object {
public:
static constexpr const char* _type_key = "Stmt";
+ static constexpr const bool _type_has_method_sequal_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(StmtNode, Object);
};
@@ -65,6 +66,13 @@ class LetStmtNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const {
+    // The bound var is a definition point.
+    if (!equal.DefEqual(var, other->var)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static Stmt make(Var var, PrimExpr value, Stmt body);
static constexpr const char* _type_key = "LetStmt";
@@ -99,6 +107,14 @@ class AttrStmtNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const {
+    if (!equal(node, other->node)) return false;
+    if (!equal(attr_key, other->attr_key)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static Stmt make(ObjectRef node,
std::string type_key,
PrimExpr value,
@@ -129,6 +145,13 @@ class AssertStmtNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const {
+    if (!equal(condition, other->condition)) return false;
+    if (!equal(message, other->message)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static Stmt make(PrimExpr condition, PrimExpr message, Stmt body);
static constexpr const char* _type_key = "AssertStmt";
@@ -152,6 +175,13 @@ class ProducerConsumerNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const {
+    if (!equal(func, other->func)) return false;
+    if (!equal(is_producer, other->is_producer)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body);
static constexpr const char* _type_key = "ProducerConsumer";
@@ -194,6 +224,14 @@ class StoreNode : public StmtNode {
v->Visit("predicate", &predicate);
}
+  bool SEqualReduce(const StoreNode* other, SEqualReducer equal) const {
+    // buffer_var is a use here, not a definition.
+    if (!equal(buffer_var, other->buffer_var)) return false;
+    if (!equal(value, other->value)) return false;
+    if (!equal(index, other->index)) return false;
+    return equal(predicate, other->predicate);
+  }
+
TVM_DLL static Stmt make(Var buffer_var,
PrimExpr value,
PrimExpr index,
@@ -224,6 +262,14 @@ class ProvideNode : public StmtNode {
v->Visit("args", &args);
}
+  bool SEqualReduce(const ProvideNode* other, SEqualReducer equal) const {
+    if (!equal(func, other->func)) return false;
+    if (!equal(value_index, other->value_index)) return false;
+    if (!equal(value, other->value)) return false;
+    return equal(args, other->args);
+  }
+
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
PrimExpr value,
@@ -261,6 +307,15 @@ class AllocateNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const {
+    // The allocated buffer var is a definition point.
+    if (!equal.DefEqual(buffer_var, other->buffer_var)) return false;
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(extents, other->extents)) return false;
+    if (!equal(condition, other->condition)) return false;
+    return equal(body, other->body);
+  }
+
TVM_DLL static Stmt make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
@@ -300,6 +355,11 @@ class FreeNode : public StmtNode {
v->Visit("buffer_var", &buffer_var);
}
+  // Two Free statements are equal when they release the same buffer var.
+  bool SEqualReduce(const FreeNode* other, SEqualReducer equal) const {
+    const bool same_var = equal(buffer_var, other->buffer_var);
+    return same_var;
+  }
+
TVM_DLL static Stmt make(Var buffer_var);
static constexpr const char* _type_key = "Free";
@@ -341,6 +401,16 @@ class RealizeNode : public StmtNode {
PrimExpr condition,
Stmt body);
+  bool SEqualReduce(const RealizeNode* other, SEqualReducer equal) const {
+    if (!equal(func, other->func)) return false;
+    if (!equal(value_index, other->value_index)) return false;
+    if (!equal(dtype, other->dtype)) return false;
+    if (!equal(bounds, other->bounds)) return false;
+    if (!equal(condition, other->condition)) return false;
+    return equal(body, other->body);
+  }
+
static constexpr const char* _type_key = "Realize";
TVM_DECLARE_FINAL_OBJECT_INFO(RealizeNode, StmtNode);
};
@@ -369,6 +439,10 @@ class SeqStmtNode : public StmtNode {
v->Visit("seq", &seq);
}
+  // Sequences are equal when their statements are pairwise equal, in order.
+  bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const {
+    const bool same_seq = equal(seq, other->seq);
+    return same_seq;
+  }
+
static constexpr const char* _type_key = "SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
@@ -472,6 +546,13 @@ class IfThenElseNode : public StmtNode {
v->Visit("else_case", &else_case);
}
+  bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const {
+    if (!equal(condition, other->condition)) return false;
+    if (!equal(then_case, other->then_case)) return false;
+    return equal(else_case, other->else_case);
+  }
+
TVM_DLL static Stmt make(PrimExpr condition, Stmt then_case, Stmt else_case
= Stmt());
static constexpr const char* _type_key = "IfThenElse";
@@ -493,6 +574,10 @@ class EvaluateNode : public StmtNode {
v->Visit("value", &value);
}
+  // Equal when the evaluated expressions are equal.
+  bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const {
+    const bool same_value = equal(value, other->value);
+    return same_value;
+  }
+
TVM_DLL static Stmt make(PrimExpr v);
static constexpr const char* _type_key = "Evaluate";
@@ -562,6 +647,16 @@ class ForNode : public StmtNode {
v->Visit("body", &body);
}
+  bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
+    // The loop var is a definition point and is mapped first.
+    if (!equal.DefEqual(loop_var, other->loop_var)) return false;
+    if (!equal(min, other->min)) return false;
+    if (!equal(extent, other->extent)) return false;
+    if (!equal(for_type, other->for_type)) return false;
+    if (!equal(device_api, other->device_api)) return false;
+    return equal(body, other->body);
+  }
+
static constexpr const char* _type_key = "For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
@@ -587,6 +682,14 @@ class PrefetchNode : public StmtNode {
v->Visit("bounds", &bounds);
}
+  bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
+    if (!equal(func, other->func)) return false;
+    if (!equal(value_index, other->value_index)) return false;
+    if (!equal(dtype, other->dtype)) return false;
+    return equal(bounds, other->bounds);
+  }
+
TVM_DLL static Stmt make(FunctionRef func,
int value_index,
DataType dtype,
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 1e11446..88af05c 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -17,6 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
+from .base import structural_equal, assert_structural_equal
from .type import Type, TypeKind, PrimType, PointerType, TypeVar,
GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 810d78f..df69a2c 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -149,3 +149,76 @@ def save_json(node):
Saved json string.
"""
return tvm.runtime._ffi_node_api.SaveJSON(node)
+
+
+def structural_equal(lhs, rhs, map_free_vars=False):
+ """Check structural equality of lhs and rhs.
+
+ The structural equality is recursively defined in the DAG of IRNodes.
+ There are two kinds of nodes:
+
+ - Graph node: a graph node in lhs can only be mapped as equal to
+ one and only one graph node in rhs.
+ - Normal node: equality is recursively defined without the restriction
+ of graph nodes.
+
+ Vars(tir::Var, TypeVar) and non-constant relay expression nodes are graph
nodes.
+ For example, it means that `%1 = %x + %y; %1 + %1` is not structurally
equal
+ to `%1 = %x + %y; %2 = %x + %y; %1 + %2` in relay.
+
+ A var-type node(e.g. tir::Var, TypeVar) can be mapped as equal to another
var
+ with the same type if one of the following condition holds:
+
+ - They appear in a same definition point(e.g. function argument).
+ - They points to the same VarNode via the same_as relation.
+ - They appear in a same usage point, and map_free_vars is set to be True.
+
+ The rules for var are used to remap variables occurs in function
+ arguments and let-bindings.
+
+ Parameters
+ ----------
+ lhs : Object
+ The left operand.
+
+ rhs : Object
+ The left operand.
+
+ map_free_vars : bool
+ Whether or not shall we map free vars that does
+ not bound to any definitions as equal to each other.
+
+ Return
+ ------
+ result : bool
+ The comparison result.
+ """
+ return tvm.runtime._ffi_node_api.StructuralEqual(
+ lhs, rhs, False, map_free_vars)
+
+
+def assert_structural_equal(lhs, rhs, map_free_vars=False):
+ """Assert lhs and rhs are structurally equal to each other.
+
+ Parameters
+ ----------
+ lhs : Object
+ The left operand.
+
+ rhs : Object
+ The left operand.
+
+ map_free_vars : bool
+ Whether or not shall we map free vars that does
+ not bound to any definitions as equal to each other.
+
+ Raises
+ ------
+ ValueError : if assertion does not hold.
+
+ See Also
+ --------
+ structural_equal
+ """
+ tvm.runtime._ffi_node_api.StructuralEqual(
+ lhs, rhs, True, map_free_vars)
diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h
index babd08a..9acc465 100644
--- a/src/ir/attr_functor.h
+++ b/src/ir/attr_functor.h
@@ -45,8 +45,8 @@ class AttrFunctor;
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
- [](const ObjectRef& n, TSelf* self, Args... args) { \
- return self->VisitAttr_(static_cast<const OP*>(n.get()), \
+ [](const ObjectRef& n, TSelf* self, Args... args) { \
+ return self->VisitAttr_(static_cast<const OP*>(n.get()), \
std::forward<Args>(args)...); \
}); \
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 9731a51..b07f04a 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -105,6 +105,7 @@ TVM_REGISTER_GLOBAL("ir.FloatImm")
TVM_REGISTER_NODE_TYPE(FloatImmNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloatImmNode*>(node.get());
@@ -143,17 +144,14 @@ TVM_REGISTER_GLOBAL("ir.Range")
*ret = Range(args[0], args[1]);
});
+TVM_REGISTER_NODE_TYPE(RangeNode);
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RangeNode*>(node.get());
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});
-TVM_REGISTER_NODE_TYPE(ArrayNode);
-TVM_REGISTER_NODE_TYPE(MapNode);
-TVM_REGISTER_NODE_TYPE(StrMapNode);
-TVM_REGISTER_NODE_TYPE(RangeNode);
-
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
diff --git a/src/ir/module.cc b/src/ir/module.cc
index 4ac769b..ca85cb8 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -65,6 +65,21 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
data_ = std::move(n);
}
+
+bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
+  // Functions are matched by the name_hint of their GlobalVar key; every
+  // matched pair must itself be structurally equal.
+  if (functions.size() != other->functions.size()) return false;
+  for (const auto& func_entry : this->functions) {
+    const auto& name = func_entry.first->name_hint;
+    if (!other->ContainGlobalVar(name) ||
+        !equal(func_entry.second, other->Lookup(name))) {
+      return false;
+    }
+  }
+  // Type definitions are matched the same way, by GlobalTypeVar name_hint.
+  if (type_definitions.size() != other->type_definitions.size()) return false;
+  for (const auto& type_entry : this->type_definitions) {
+    const auto& name = type_entry.first->name_hint;
+    if (!other->ContainGlobalTypeVar(name) ||
+        !equal(type_entry.second, other->LookupTypeDef(name))) {
+      return false;
+    }
+  }
+  return true;
+}
+
bool IRModuleNode::ContainGlobalVar(const std::string& name) const {
return global_var_map_.find(name) != global_var_map_.end();
}
@@ -305,8 +320,8 @@ IRModule IRModule::FromExpr(
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
BaseFunc func;
- if (auto* func_node = expr.as<relay::FunctionNode>()) {
- func = GetRef<relay::Function>(func_node);
+ if (auto* func_node = expr.as<BaseFuncNode>()) {
+ func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
diff --git a/src/node/container.cc b/src/node/container.cc
index 25bfe9d..fc5c62a 100644
--- a/src/node/container.cc
+++ b/src/node/container.cc
@@ -21,11 +21,98 @@
* \file src/node/container.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
+#include <cstring>
namespace tvm {
+// SEqualReduce traits for runtime containers.
+struct StringObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  /*!
+   * \brief Two runtime strings are structurally equal iff they have the
+   *  same length and the same bytes.
+   */
+  static bool SEqualReduce(const runtime::StringObj* lhs,
+                           const runtime::StringObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->size != rhs->size) return false;
+    // Same backing buffer and same length: trivially equal.
+    if (lhs->data == rhs->data) return true;
+    // memcmp returns 0 when the byte ranges match.
+    return std::memcmp(lhs->data, rhs->data, lhs->size) == 0;
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
+
+struct ADTObjTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  // Two ADT values are equal when tag, field count and every field match.
+  static bool SEqualReduce(const runtime::ADTObj* lhs,
+                           const runtime::ADTObj* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+    if (lhs->tag != rhs->tag || lhs->size != rhs->size) return false;
+    for (uint32_t idx = 0; idx < lhs->size; ++idx) {
+      if (!equal((*lhs)[idx], (*rhs)[idx])) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);
+
+
+struct NDArrayContainerTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  /*!
+   * \brief Compare two CPU, contiguous NDArrays by dtype, shape and raw bytes.
+   *  Aborts (CHECK) if either tensor is not on CPU or not contiguous.
+   */
+  static bool SEqualReduce(const runtime::NDArray::Container* lhs,
+                           const runtime::NDArray::Container* rhs,
+                           SEqualReducer equal) {
+    if (lhs == rhs) return true;
+
+    auto ldt = lhs->dl_tensor.dtype;
+    auto rdt = rhs->dl_tensor.dtype;
+    CHECK_EQ(lhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    CHECK_EQ(rhs->dl_tensor.ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    CHECK(runtime::IsContiguous(lhs->dl_tensor))
+        << "Can only compare contiguous tensor";
+    CHECK(runtime::IsContiguous(rhs->dl_tensor))
+        << "Can only compare contiguous tensor";
+    // Tensors of different rank or shape are never equal. This must be
+    // checked before the memcmp: the byte count below is computed from lhs
+    // only, so a smaller rhs buffer would otherwise be read out of bounds.
+    if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
+    for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
+      if (lhs->dl_tensor.shape[i] != rhs->dl_tensor.shape[i]) return false;
+    }
+    if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
+      // Same dtype and shape => both buffers hold data_size bytes.
+      size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
+      return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
+    } else {
+      return false;
+    }
+  }
+};
+
+TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait);
+
+
+struct ArrayNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  // Arrays are equal when they have the same length and are element-wise equal.
+  static bool SEqualReduce(const ArrayNode* lhs,
+                           const ArrayNode* rhs,
+                           SEqualReducer equal) {
+    const size_t count = lhs->data.size();
+    if (count != rhs->data.size()) return false;
+    for (size_t idx = 0; idx < count; ++idx) {
+      if (!equal(lhs->data[idx], rhs->data[idx])) return false;
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(ArrayNode);
+TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<ArrayNode>();
+  });
+
+
TVM_REGISTER_GLOBAL("node.Array")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data;
@@ -62,6 +149,59 @@ TVM_REGISTER_GLOBAL("node.ArraySize")
static_cast<const ArrayNode*>(ptr)->data.size());
});
+
+struct MapNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  /*!
+   * \brief Maps are equal when every lhs key has an already-established
+   *  rhs counterpart (via MapLhsToRhs) whose value is structurally equal.
+   *  This supports Map<Var, Value> whose keys were bound earlier, e.g. as
+   *  function parameters.
+   */
+  static bool SEqualReduce(const MapNode* lhs,
+                           const MapNode* rhs,
+                           SEqualReducer equal) {
+    if (lhs->data.size() != rhs->data.size()) return false;
+    for (const auto& entry : lhs->data) {
+      ObjectRef mapped_key = equal->MapLhsToRhs(entry.first);
+      // Unmapped keys are not compared deeply; treat them as a mismatch.
+      if (!mapped_key.defined()) return false;
+      auto match = rhs->data.find(mapped_key);
+      if (match == rhs->data.end() || !equal(entry.second, match->second)) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(MapNode);
+TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<MapNode>();
+  });
+
+
+struct StrMapNodeTrait {
+  static constexpr const std::nullptr_t VisitAttrs = nullptr;
+
+  // String-keyed maps: keys must match exactly, values structurally.
+  static bool SEqualReduce(const StrMapNode* lhs,
+                           const StrMapNode* rhs,
+                           SEqualReducer equal) {
+    if (lhs->data.size() != rhs->data.size()) return false;
+    for (const auto& entry : lhs->data) {
+      auto match = rhs->data.find(entry.first);
+      if (match == rhs->data.end() || !equal(entry.second, match->second)) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+TVM_REGISTER_OBJECT_TYPE(StrMapNode);
+TVM_REGISTER_REFLECTION_VTABLE(StrMapNode, StrMapNodeTrait)
+.set_creator([](const std::string&) -> ObjectPtr<Object> {
+    return ::tvm::runtime::make_object<StrMapNode>();
+  });
+
+
TVM_REGISTER_GLOBAL("node.Map")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
diff --git a/src/node/reflection.cc b/src/node/reflection.cc
index 183079f..824874f 100644
--- a/src/node/reflection.cc
+++ b/src/node/reflection.cc
@@ -180,7 +180,7 @@ ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
- if (tindex >= fvisit_attrs_.size() || fvisit_attrs_[tindex] == nullptr) {
+ if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
new file mode 100644
index 0000000..23dfe15
--- /dev/null
+++ b/src/node/structural_equal.cc
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/node/structural_equal.cc
+ */
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/reflection.h>
+#include <tvm/node/functor.h>
+#include <tvm/node/node.h>
+#include <tvm/runtime/registry.h>
+
+#include <unordered_map>
+
+namespace tvm {
+
+// Define the dispatch function here since its primary user is in this file.
+// Looks up the SEqualReduce registered for self's type index and calls it.
+bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other,
+                                    SEqualReducer equal) const {
+  const uint32_t tindex = self->type_index();
+  const bool registered = tindex < fsequal_.size() && fsequal_[tindex] != nullptr;
+  if (!registered) {
+    LOG(FATAL) << "TypeError: SEqualReduce of " << self->GetTypeKey()
+               << " is not registered via TVM_REGISTER_NODE_TYPE";
+  }
+  return fsequal_[tindex](self, other, equal);
+}
+
+/*!
+ * \brief A non-recursive, stack-based SEqual handler that can remap vars.
+ *
+ * This handler pushes the Object equality cases onto a stack, and
+ * traverses the stack to expand the necessary children that need to be checked.
+ *
+ * The order of SEqual being called is the same as the order as if we
+ * eagerly did recursive calls in SEqualReduce.
+ */
+class RemapVarSEqualHandler :
+      public SEqualReducer::Handler {
+ public:
+  explicit RemapVarSEqualHandler(bool assert_mode)
+      : assert_mode_(assert_mode) {}
+
+  /*!
+   * \brief Reduce hook: called for the root pair by Equal() and, through
+   *  SEqualReducer, for child pairs while an object is being expanded.
+   *
+   * Performs the cheap checks (definedness, type index, existing mapping)
+   * immediately and defers the deep comparison by queueing a Task into
+   * pending_tasks_. Two undefined refs are equal and queue nothing.
+   * \return false if the pair is already known to be unequal.
+   */
+  bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) final {
+    // We cannot use lhs.same_as(rhs) to check equality
+    // if we choose to enable var remapping.
+    //
+    // Counter example below (%x, %y) are shared vars
+    // between the two functions(possibly before/after rewriting).
+    //
+    // - function0: fn (%x, %y) { %x + %y }
+    // - function1. fn (%y, %x) { %x + %y }
+    //
+    // Because we choose to enable var remapping,
+    // %x is mapped to %y, and %y is mapped to %x,
+    // the body of the function no longer means the same thing.
+    //
+    // Take away: We can either choose only compare Var by address,
+    // in which case we can use same_as for quick checking,
+    // or we have to run deep comparison and avoid to use same_as checks.
+    auto run = [=]() {
+      if (!lhs.defined() && !rhs.defined()) return true;
+      if (!lhs.defined() && rhs.defined()) return false;
+      if (!rhs.defined() && lhs.defined()) return false;
+      if (lhs->type_index() != rhs->type_index()) return false;
+      auto it = equal_map_lhs_.find(lhs);
+      if (it != equal_map_lhs_.end()) {
+        return it->second.same_as(rhs);
+      }
+      if (equal_map_rhs_.count(rhs)) return false;
+      // need to push to pending tasks in this case
+      pending_tasks_.emplace_back(Task(lhs, rhs, map_free_vars));
+      return true;
+    };
+    return CheckResult(run(), lhs, rhs);
+  }
+
+  void MarkGraphNode() final {
+    // Only valid while the entry at the top of the stack is being expanded
+    // (allow_push_to_stack_ is false only around DispatchSEqualReduce in
+    // RunTasks). Flag it so a lhs <-> rhs mapping is recorded once all of
+    // its children have been checked.
+    CHECK(!allow_push_to_stack_ && !task_stack_.empty());
+    task_stack_.back().graph_equal = true;
+  }
+
+  ObjectRef MapLhsToRhs(const ObjectRef& lhs) final {
+    auto it = equal_map_lhs_.find(lhs);
+    if (it != equal_map_lhs_.end()) return it->second;
+    return ObjectRef(nullptr);
+  }
+
+  /*!
+   * \brief Function that implements the actual equality check.
+   * \param lhs The left operand.
+   * \param rhs The right operand.
+   * \param map_free_vars Whether free vars may be mapped to each other.
+   * \return Whether lhs and rhs are structurally equal.
+   */
+  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+    task_stack_.clear();
+    pending_tasks_.clear();
+    equal_map_lhs_.clear();
+    equal_map_rhs_.clear();
+    // A previous run may have stopped in the middle of an expansion
+    // (mismatch or assert-mode failure); start from a clean state.
+    allow_push_to_stack_ = true;
+    if (!SEqualReduce(lhs, rhs, map_free_vars)) return false;
+    // Both roots undefined: SEqualReduce succeeds without queueing a task.
+    if (pending_tasks_.empty()) return true;
+    CHECK_EQ(pending_tasks_.size(), 1U);
+    task_stack_.emplace_back(std::move(pending_tasks_.back()));
+    pending_tasks_.clear();
+    return RunTasks();
+  }
+
+ protected:
+  // Check the result; in assert mode a failed check is fatal.
+  bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
+    if (assert_mode_ && !result) {
+      LOG(FATAL)
+          << "ValueError: StructuralEqual check failed, caused by\n"
+          << "lhs = " << lhs << "\nrhs = " << rhs;
+    }
+    return result;
+  }
+  /*!
+   * \brief Run tasks until the task stack is empty.
+   * \return Whether every expanded task passed its equality checks.
+   */
+  bool RunTasks() {
+    while (task_stack_.size() != 0) {
+      // Caution: entry becomes invalid when the stack changes
+      auto& entry = task_stack_.back();
+
+      if (entry.children_expanded) {
+        // When all the children has expanded and visited.
+        // This means all the condition checks for
+        // the current entry has been passed
+        // We can safely mark lhs and rhs as equal to each other.
+        auto it = equal_map_lhs_.find(entry.lhs);
+        if (it != equal_map_lhs_.end()) {
+          CHECK(it->second.same_as(entry.rhs));
+        }
+        // Record the mapping if the task was marked as graph equal.
+        if (entry.graph_equal) {
+          equal_map_lhs_[entry.lhs] = entry.rhs;
+          equal_map_rhs_[entry.rhs] = entry.lhs;
+        }
+        task_stack_.pop_back();
+      } else {
+        // mark before expand
+        // Important: because entry becomes invalid when stack changes.
+        entry.children_expanded = true;
+        // Expand the objects
+        // The SEqual of the object can call into this->SEqualReduce
+        // which populates the pending tasks.
+        CHECK_EQ(pending_tasks_.size(), 0U);
+        allow_push_to_stack_ = false;
+        bool children_ok = DispatchSEqualReduce(entry.lhs, entry.rhs, entry.map_free_vars);
+        // Restore before any early return so the handler stays consistent.
+        allow_push_to_stack_ = true;
+        if (!children_ok) return false;
+        // Push pending tasks in reverse order, so earlier tasks get to
+        // expand first in the stack
+        while (pending_tasks_.size() != 0) {
+          task_stack_.emplace_back(std::move(pending_tasks_.back()));
+          pending_tasks_.pop_back();
+        }
+      }
+    }
+    return true;
+  }
+
+  // The default equal as registered in the structural equal vtable.
+  bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
+    auto compute = [=]() {
+      CHECK(lhs.defined() &&
+            rhs.defined() &&
+            lhs->type_index() == rhs->type_index());
+      // skip entries that already have equality maps.
+      auto it = equal_map_lhs_.find(lhs);
+      if (it != equal_map_lhs_.end()) {
+        return it->second.same_as(rhs);
+      }
+      if (equal_map_rhs_.count(rhs)) return false;
+      // Run reduce check for free nodes.
+      return vtable_->SEqualReduce(lhs.get(), rhs.get(), SEqualReducer(this, map_free_vars));
+    };
+    return CheckResult(compute(), lhs, rhs);
+  }
+
+ private:
+  /*! \brief Pending reduce tasks. */
+  struct Task {
+    /*! \brief The lhs operand to be compared. */
+    ObjectRef lhs;
+    /*! \brief The rhs operand to be compared. */
+    ObjectRef rhs;
+    /*! \brief The map free var argument. */
+    bool map_free_vars;
+    /*! \brief Whether the children has been expanded via SEqualReduce */
+    bool children_expanded{false};
+    /*! \brief whether the task is about graph equality(need remap). */
+    bool graph_equal{false};
+
+    Task() = default;
+    Task(ObjectRef lhs, ObjectRef rhs, bool map_free_vars)
+        : lhs(lhs), rhs(rhs), map_free_vars(map_free_vars) {}
+  };
+  // list of pending tasks to be pushed to the stack.
+  std::vector<Task> pending_tasks_;
+  // Internal task stack used to execute the tasks.
+  std::vector<Task> task_stack_;
+  // Whether we allow push to stack.
+  bool allow_push_to_stack_{true};
+  // If in assert mode, must return true, and will throw error otherwise.
+  bool assert_mode_{false};
+  // reflection vtable
+  ReflectionVTable* vtable_ = ReflectionVTable::Global();
+  // map from lhs to rhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_lhs_;
+  // map from rhs to lhs
+  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_rhs_;
+};
+
+
+// FFI entry point: structural equality with configurable assert mode and
+// free-var mapping.
+TVM_REGISTER_GLOBAL("node.StructuralEqual")
+.set_body_typed([](const ObjectRef& lhs,
+                   const ObjectRef& rhs,
+                   bool assert_mode,
+                   bool map_free_vars) {
+  RemapVarSEqualHandler handler(assert_mode);
+  return handler.Equal(lhs, rhs, map_free_vars);
+});
+
+// C++ functor: non-asserting check that does not map free vars.
+bool StructuralEqual::operator()(const ObjectRef& lhs,
+                                 const ObjectRef& rhs) const {
+  constexpr bool kAssertMode = false;
+  constexpr bool kMapFreeVars = false;
+  return RemapVarSEqualHandler(kAssertMode).Equal(lhs, rhs, kMapFreeVars);
+}
+
+} // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 07759b3..bee0256 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -81,7 +81,8 @@ TVM_REGISTER_GLOBAL("tir.Var")
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
return SizeVar(s, t);
- });
+});
+
IterVar IterVarNode::make(Range dom,
Var var,
@@ -132,6 +133,7 @@ PrimExpr StringImmNode::make(std::string value) {
TVM_REGISTER_GLOBAL("tir.StringImm")
.set_body_typed(StringImmNode::make);
+
PrimExpr CastNode::make(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
@@ -141,6 +143,7 @@ PrimExpr CastNode::make(DataType t, PrimExpr value) {
return PrimExpr(node);
}
+
PrimExpr AndNode::make(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
@@ -169,6 +172,7 @@ PrimExpr OrNode::make(PrimExpr a, PrimExpr b) {
return PrimExpr(node);
}
+
PrimExpr NotNode::make(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
@@ -179,6 +183,8 @@ PrimExpr NotNode::make(PrimExpr a) {
return PrimExpr(node);
}
+
+
PrimExpr SelectNode::make(PrimExpr condition, PrimExpr true_value, PrimExpr
false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
@@ -270,11 +276,11 @@ bool CallNode::is_vectorizable() const {
}
PrimExpr CallNode::make(DataType dtype,
- std::string name,
- Array<PrimExpr> args,
- CallType call_type,
- FunctionRef func,
- int value_index) {
+ std::string name,
+ Array<PrimExpr> args,
+ CallType call_type,
+ FunctionRef func,
+ int value_index) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
diff --git a/tests/python/frontend/tensorflow/test_forward.py
b/tests/python/frontend/tensorflow/test_forward.py
index 9d875c1..b1efe4a 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1114,7 +1114,7 @@ def test_read_variable_op():
num_output=len(out_name))
for i in range(len(tf_output)):
tvm.testing.assert_allclose(
- tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
+ tf_output[i], tvm_output[i], atol=1e-4, rtol=1e-5)
sess.close()
diff --git a/tests/python/relay/test_ir_parser.py
b/tests/python/relay/test_ir_parser.py
index fbe5213..9e62491 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -17,8 +17,6 @@
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import graph_equal, assert_graph_equal
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal
import pytest
from numpy import isclose
from typing import Union
@@ -69,6 +67,13 @@ type List[A] {
}
"""
+def assert_graph_equal(lhs, rhs):
+ tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars=True)
+
+def graph_equal(lhs, rhs):
+ return tvm.ir.structural_equal(lhs, rhs, map_free_vars=True)
+
+
def roundtrip(expr):
x = relay.fromtext(expr.astext())
assert_graph_equal(x, expr)
@@ -86,6 +91,12 @@ def parses_as(code, expr):
result = graph_equal(parsed, expr)
return result
+
+def assert_parses_as(code, expr):
+ parsed = parse_text(code)
+ assert_graph_equal(parsed, expr)
+
+
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
return x.data.asnumpy().item()
@@ -102,7 +113,7 @@ UNIT = relay.Tuple([])
def test_comments():
- assert parses_as(
+ assert_parses_as(
"""
// This is a line comment!
()
@@ -110,7 +121,7 @@ def test_comments():
UNIT
)
- assert parses_as(
+ assert_parses_as(
"""
/* This is a block comment!
This is still a block comment!
@@ -120,7 +131,7 @@ def test_comments():
UNIT
)
- assert parses_as(
+ assert_parses_as(
"""
/* This is a block comment!
/*Block comment is recursive!*/
@@ -172,7 +183,7 @@ def test_negative():
def test_bin_op():
for bin_op in BINARY_OPS.keys():
- assert parses_as(
+ assert_parses_as(
"1 {} 1".format(bin_op),
BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
)
@@ -213,7 +224,7 @@ def test_vars():
def test_let():
- assert parses_as(
+ assert_parses_as(
"let %x = 1; ()",
relay.Let(
X,
@@ -222,7 +233,7 @@ def test_let():
)
)
- assert parses_as(
+ assert_parses_as(
"""
let %x = 1;
let %y = 2;
@@ -241,7 +252,7 @@ def test_let():
def test_seq():
- assert parses_as(
+ assert_parses_as(
"();; ()",
relay.Let(
_,
@@ -249,7 +260,7 @@ def test_seq():
UNIT)
)
- assert parses_as(
+ assert_parses_as(
"let %_ = 1; ()",
relay.Let(
X,
@@ -261,14 +272,10 @@ def test_seq():
def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
- assert parses_as(
+ assert_parses_as(
code,
relay.Tuple([UNIT, UNIT, relay.const(1)])
)
- assert not parses_as(
- code,
- relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
- )
@raises_parse_error
@@ -287,18 +294,18 @@ def test_let_op():
def test_tuple():
- assert parses_as("()", relay.Tuple([]))
+ assert_parses_as("()", relay.Tuple([]))
- assert parses_as("(0,)", relay.Tuple([relay.const(0)]))
+ assert_parses_as("(0,)", relay.Tuple([relay.const(0)]))
- assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
+ assert_parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
- assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1),
relay.const(2)]))
+ assert_parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1),
relay.const(2)]))
def test_func():
# 0 args
- assert parses_as(
+ assert_parses_as(
"fn () { 0 }",
relay.Function(
[],
@@ -309,7 +316,7 @@ def test_func():
)
# 1 arg
- assert parses_as(
+ assert_parses_as(
"fn (%x) { %x }",
relay.Function(
[X],
@@ -320,7 +327,7 @@ def test_func():
)
# 2 args
- assert parses_as(
+ assert_parses_as(
"fn (%x, %y) { %x + %y }",
relay.Function(
[X, Y],
@@ -331,7 +338,7 @@ def test_func():
)
# annotations
- assert parses_as(
+ assert_parses_as(
"fn (%x: int32) -> int32 { %x }",
relay.Function(
[X_ANNO],
@@ -342,7 +349,7 @@ def test_func():
)
# attributes
- assert parses_as(
+ assert_parses_as(
"fn (n=5) { () }",
relay.Function([], UNIT, None, None, tvm.ir.make_node("DictAttrs",
n=relay.const(5)))
)
@@ -370,7 +377,7 @@ def test_recursive_call():
def test_ifelse():
- assert parses_as(
+ assert_parses_as(
"""
if (True) {
0
@@ -403,7 +410,7 @@ def test_ifelse_scope():
def test_call():
# select right function to call: simple ident case
id_func = relay.Var("id")
- assert parses_as(
+ assert_parses_as(
"""
let %id = fn (%x) { %x };
10 * %id(10)
@@ -417,7 +424,7 @@ def test_call():
# 0 args
constant = relay.Var("constant")
- assert parses_as(
+ assert_parses_as(
"""
let %constant = fn () { 0 };
%constant()
@@ -431,7 +438,7 @@ def test_call():
# 1 arg
id_var = relay.Var("id")
- assert parses_as(
+ assert_parses_as(
"""
let %id = fn (%x) { %x };
%id(1)
@@ -445,7 +452,7 @@ def test_call():
# 2 args
multiply = relay.Var("multiply")
- assert parses_as(
+ assert_parses_as(
"""
let %multiply = fn (%x, %y) { %x * %y };
%multiply(0, 0)
@@ -463,7 +470,7 @@ def test_call():
)
# anonymous function
- assert parses_as(
+ assert_parses_as(
"""
(fn (%x) { %x })(0)
""",
@@ -483,7 +490,7 @@ def test_call():
# TODO(@jmp): re-enable after sequence parsing improvements
# curried function
# curried_mult = relay.Var("curried_mult")
- # assert parses_as(
+ # assert_parses_as(
# """
# let %curried_mult =
# fn (%x) {
@@ -516,7 +523,7 @@ def test_call():
# )
# op
- assert parses_as(
+ assert_parses_as(
"abs(1)",
relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
)
@@ -525,7 +532,7 @@ def test_call():
def test_incomplete_type():
- assert parses_as(
+ assert_parses_as(
"let %_ : _ = (); ()",
relay.Let(
_,
@@ -541,7 +548,7 @@ def test_builtin_types():
def test_tensor_type():
- assert parses_as(
+ assert_parses_as(
"let %_ : Tensor[(), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((), "float32")),
@@ -550,7 +557,7 @@ def test_tensor_type():
)
)
- assert parses_as(
+ assert_parses_as(
"let %_ : Tensor[(1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")),
@@ -559,7 +566,7 @@ def test_tensor_type():
)
)
- assert parses_as(
+ assert_parses_as(
"let %_ : Tensor[(1, 1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1, 1), "float32")),
@@ -570,7 +577,7 @@ def test_tensor_type():
def test_function_type():
- assert parses_as(
+ assert_parses_as(
"""
let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
""",
@@ -581,7 +588,7 @@ def test_function_type():
)
)
- assert parses_as(
+ assert_parses_as(
"""
let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
""",
@@ -592,7 +599,7 @@ def test_function_type():
)
)
- assert parses_as(
+ assert_parses_as(
"""
let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) ->
int32 { 0 }; ()
""",
@@ -605,7 +612,7 @@ def test_function_type():
def test_tuple_type():
- assert parses_as(
+ assert_parses_as(
"""
let %_: () = (); ()
""",
@@ -616,7 +623,7 @@ def test_tuple_type():
)
)
- assert parses_as(
+ assert_parses_as(
"""
let %_: (int32,) = (0,); ()
""",
@@ -627,7 +634,7 @@ def test_tuple_type():
)
)
- assert parses_as(
+ assert_parses_as(
"""
let %_: (int32, int32) = (0, 1); ()
""",
@@ -648,7 +655,7 @@ def test_adt_defn():
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
- assert parses_as(
+ assert_parses_as(
"""
type Ayy { Nil }
""",
@@ -662,7 +669,7 @@ def test_empty_adt_defn():
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
- assert parses_as(
+ assert_parses_as(
"""
type Ayy { }
""",
@@ -683,7 +690,7 @@ def test_multiple_cons_defn():
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
- assert parses_as(LIST_DEFN, mod)
+ assert_parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
@@ -699,7 +706,7 @@ def test_multiple_type_param_defn():
])
mod = tvm.IRModule()
mod[glob_typ_var] = prog
- assert parses_as(
+ assert_parses_as(
"""
type Either[A, B] {
Left(A),
@@ -755,7 +762,7 @@ def test_match():
)
mod[length_var] = length_func
- assert parses_as(
+ assert_parses_as(
"""
%s
@@ -796,7 +803,7 @@ def test_adt_cons_expr():
)
mod[make_singleton_var] = make_singleton_func
- assert parses_as(
+ assert_parses_as(
"""
%s
@@ -861,7 +868,7 @@ def test_extern_adt_defn():
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
- assert parses_as(
+ assert_parses_as(
"""
extern type T[A]
""",
@@ -872,6 +879,7 @@ def test_import_grad():
mod.import_from_std("gradient.rly")
if __name__ == "__main__":
+ test_graph()
test_comments()
test_int_literal()
test_float_literal()
@@ -882,7 +890,6 @@ if __name__ == "__main__":
test_op_assoc()
test_let()
test_seq()
- test_graph()
test_tuple()
test_func()
test_defn()
@@ -905,4 +912,4 @@ if __name__ == "__main__":
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
- test_import_grad()
\ No newline at end of file
+ test_import_grad()
diff --git a/tests/python/relay/test_pass_alpha_equal.py
b/tests/python/relay/test_ir_structural_equal.py
similarity index 78%
rename from tests/python/relay/test_pass_alpha_equal.py
rename to tests/python/relay/test_ir_structural_equal.py
index 411906d..5881ab9 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_ir_structural_equal.py
@@ -21,23 +21,24 @@ from tvm import relay
from tvm.relay import analysis
from tvm.relay.testing import run_opt_pass
-def alpha_equal(x, y):
+def sequal(x, y):
"""
Wrapper around alpha equality which ensures that
the hash function respects equality.
"""
- return analysis.alpha_equal(x, y) and analysis.structural_hash(x) ==
analysis.structural_hash(y)
+ return (tvm.ir.structural_equal(x, y) and
+ analysis.structural_hash(x) == analysis.structural_hash(y))
-def alpha_equal_commutative(x, y):
+def sequal_commutative(x, y):
"""
Check for commutative property of equality
"""
- xy = analysis.alpha_equal(x, y)
- yx = analysis.alpha_equal(y, x)
+ xy = tvm.ir.structural_equal(x, y)
+ yx = tvm.ir.structural_equal(y, x)
assert xy == yx
return xy
-def test_tensor_type_alpha_equal():
+def test_tensor_type_sequal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
t3 = relay.TensorType((3, 4, 5), "float32")
@@ -49,7 +50,7 @@ def test_tensor_type_alpha_equal():
assert t1 == t2
-def test_incomplete_type_alpha_equal():
+def test_incomplete_type_sequal():
t1 = relay.IncompleteType(relay.TypeKind.ShapeVar)
t2 = relay.IncompleteType(relay.TypeKind.Type)
t3 = relay.IncompleteType(relay.TypeKind.Type)
@@ -61,7 +62,7 @@ def test_incomplete_type_alpha_equal():
assert t2 != t3
-def test_type_param_alpha_equal():
+def test_type_param_sequal():
t1 = relay.TypeVar("v1", relay.TypeKind.Type)
t2 = relay.TypeVar("v2", relay.TypeKind.ShapeVar)
t3 = relay.TypeVar("v3", relay.TypeKind.Type)
@@ -83,7 +84,7 @@ def test_type_param_alpha_equal():
assert ft1 != ft3 # kinds still do not match
-def test_func_type_alpha_equal():
+def test_func_type_sequal():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
@@ -143,7 +144,7 @@ def test_func_type_alpha_equal():
assert ft != more_rels
-def test_tuple_type_alpha_equal():
+def test_tuple_type_sequal():
t1 = relay.TensorType((1, 2, 3), "float32")
t2 = relay.TensorType((1, 2, 3, 4), "float32")
tp1 = relay.TypeVar("v1", relay.TypeKind.Type)
@@ -161,7 +162,7 @@ def test_tuple_type_alpha_equal():
assert tup1 != tup4
-def test_type_relation_alpha_equal():
+def test_type_relation_sequal():
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32")
@@ -197,7 +198,7 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs
-def test_type_call_alpha_equal():
+def test_type_call_sequal():
h1 = relay.GlobalTypeVar("h1")
h2 = relay.GlobalTypeVar("h2")
t1 = relay.TensorType((1, 2), "float32")
@@ -221,49 +222,49 @@ def test_type_call_alpha_equal():
assert tc != different_order_args
-def test_constant_alpha_equal():
+def test_constant_sequal():
x = relay.const(1)
y = relay.const(2)
- assert alpha_equal(x, x)
- assert not alpha_equal(x, y)
- assert alpha_equal(x, relay.const(1))
+ assert sequal(x, x)
+ assert not sequal(x, y)
+ assert sequal(x, relay.const(1))
-def test_type_node_alpha_equal():
+def test_type_node_sequal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.TypeVar('v2', 6)
- assert not alpha_equal(v1, v2)
+ assert not sequal(v1, v2)
v1 = relay.TypeVar('v1', 0)
v2 = relay.TypeVar('v2', 6)
- assert not alpha_equal(v1, v2)
+ assert not sequal(v1, v2)
- assert alpha_equal_commutative(v1, v1)
+ assert sequal_commutative(v1, v1)
-def test_type_node_incompatible_alpha_equal():
+def test_type_node_incompatible_sequal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.Var("v2")
- assert not alpha_equal_commutative(v1, v2)
+ assert not sequal_commutative(v1, v2)
-def test_expr_node_incompatible_alpha_equal():
+def test_expr_node_incompatible_sequal():
v1 = relay.Var("v1")
v2 = relay.PatternVar(relay.Var("v2"))
- assert not alpha_equal_commutative(v1, v2)
+ assert not sequal_commutative(v1, v2)
-def test_var_alpha_equal():
+def test_var_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# normally only pointer equality
- assert alpha_equal(v1, v1)
- assert not alpha_equal(v1, v2)
+ assert sequal(v1, v1)
+ assert not sequal(v1, v2)
# let node allows for setting the eq_map
l1 = relay.Let(v1, relay.const(1), v1)
l2 = relay.Let(v2, relay.const(1), v2)
l3 = relay.Let(v1, relay.const(1), v2)
- assert alpha_equal(l1, l2)
- assert not alpha_equal(l1, l3)
+ assert sequal(l1, l2)
+ assert not sequal(l1, l3)
# type annotations
tt1 = relay.TensorType([], "int32")
@@ -278,34 +279,34 @@ def test_var_alpha_equal():
l6 = relay.Let(v5, relay.const(1), v5)
# same annotations
- assert alpha_equal(l4, l5)
+ assert sequal(l4, l5)
# different annotations
- assert not alpha_equal(l4, l6)
+ assert not sequal(l4, l6)
# one null annotation
- assert not alpha_equal(l1, l4)
+ assert not sequal(l1, l4)
-def test_global_var_alpha_equal():
+def test_global_var_sequal():
v1 = relay.GlobalVar("v1")
v2 = relay.GlobalVar("v2")
# only pointer equality suffices (smoke test)
- assert alpha_equal(v1, v1)
- assert not alpha_equal(v1, v2)
+ assert sequal(v1, v1)
+ assert not sequal(v1, v2)
-def test_tuple_alpha_equal():
+def test_tuple_sequal():
v0 = relay.Var("v0")
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# unit value is a valid tuple
- assert alpha_equal(relay.Tuple([]), relay.Tuple([]))
+ assert sequal(relay.Tuple([]), relay.Tuple([]))
tup = relay.Tuple([v0, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])])
same = relay.Tuple([v0, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])])
- assert alpha_equal(tup, same)
+ assert sequal(tup, same)
# use the eq_map
@@ -315,33 +316,33 @@ def test_tuple_alpha_equal():
relay.Tuple([relay.const(4)])]),
v2)
- assert alpha_equal(let_tup, let_mapped)
+ assert sequal(let_tup, let_mapped)
more_fields = relay.Tuple([v1, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)]), v2])
- assert not alpha_equal(tup, more_fields)
+ assert not sequal(tup, more_fields)
fewer_fields = relay.Tuple([v1, relay.const(2), relay.const(3)])
- assert not alpha_equal(tup, fewer_fields)
+ assert not sequal(tup, fewer_fields)
different_end = relay.Tuple([v1, relay.const(2), relay.const(3),
relay.Tuple([relay.const(5)])])
- assert not alpha_equal(tup, different_end)
+ assert not sequal(tup, different_end)
different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4)])])
- assert not alpha_equal(tup, different_start)
+ assert not sequal(tup, different_start)
longer_at_end = relay.Tuple([v1, relay.const(2), relay.const(3),
relay.Tuple([relay.const(4),
relay.const(5)])])
- assert not alpha_equal(tup, longer_at_end)
+ assert not sequal(tup, longer_at_end)
-def test_tuple_get_item_alpha_equal():
+def test_tuple_get_item_sequal():
x = relay.Var('x')
y = relay.Var('y')
- assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
- assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
- assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
+ assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
+ assert not sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
+ assert sequal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
def test_function_attr():
@@ -364,10 +365,10 @@ def test_function_attr():
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
- assert not alpha_equal(func0, func1)
+ assert not sequal(func0, func1)
-def test_function_alpha_equal():
+def test_function_sequal():
tt1 = relay.TensorType((1, 2, 3), "float32")
tt2 = relay.TensorType((4, 5, 6), "int8")
tt3 = relay.TupleType([tt1, tt2])
@@ -389,58 +390,58 @@ def test_function_alpha_equal():
func = relay.Function([v1, v2], v1,
tt2, basic_tps)
mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
- assert alpha_equal(func, mapped)
+ assert sequal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
- assert not alpha_equal(func, fewer_params)
+ assert not sequal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2),
relay.Var("v2", tt2)], v4, tt2, basic_tps)
- assert not alpha_equal(func, more_params)
+ assert not sequal(func, more_params)
params_unordered = relay.Function([v2, v1], v1,
tt2, basic_tps)
- assert not alpha_equal(func, params_unordered)
+ assert not sequal(func, params_unordered)
params_mismatch = relay.Function([v1, v3], v1,
tt2, basic_tps)
- assert not alpha_equal(func, params_mismatch)
+ assert not sequal(func, params_mismatch)
# also would not typecheck
ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
- assert not alpha_equal(func, ret_type_mismatch)
+ assert not sequal(func, ret_type_mismatch)
# also mis-typed
different_body = relay.Function(basic_args, v3, tt2, basic_tps)
- assert not alpha_equal(func, different_body)
+ assert not sequal(func, different_body)
fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
- assert not alpha_equal(func, fewer_type_params)
+ assert not sequal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
- assert not alpha_equal(func, more_type_params)
+ assert not sequal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
- assert not alpha_equal(func, type_params_unordered)
+ assert not sequal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
- assert not alpha_equal(func, different_type_params)
+ assert not sequal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
- assert not alpha_equal(func, tupled_example)
+ assert not sequal(func, tupled_example)
# nullable
no_ret_type = relay.Function(basic_args, v4, None, [tp1, tp2])
# both null
- assert alpha_equal(no_ret_type, no_ret_type)
+ assert sequal(no_ret_type, no_ret_type)
# one null
- assert not alpha_equal(func, no_ret_type)
- assert not alpha_equal(no_ret_type, func)
+ assert not sequal(func, no_ret_type)
+ assert not sequal(no_ret_type, func)
-def test_call_alpha_equal():
+def test_call_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
@@ -458,43 +459,43 @@ def test_call_alpha_equal():
call = relay.Call(v1, [relay.const(1), relay.const(2), v2,
relay.Tuple([])],
attr1, [tt1])
same = relay.Call(v1, basic_args, attr1, [tt1])
- assert alpha_equal(call, same)
+ assert sequal(call, same)
different_fn = relay.Call(v2, basic_args, attr1, [tt1])
- assert not alpha_equal(call, different_fn)
+ assert not sequal(call, different_fn)
fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1,
[tt1])
- assert not alpha_equal(call, fewer_args)
+ assert not sequal(call, fewer_args)
reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
relay.Tuple([]), v2], attr1, [tt1])
- assert not alpha_equal(call, reordered_args)
+ assert not sequal(call, reordered_args)
different_args = relay.Call(v1, [relay.const(1), relay.const(2),
relay.const(3)],
attr1, [tt1])
- assert not alpha_equal(call, different_args)
+ assert not sequal(call, different_args)
more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2,
relay.Tuple([]),
relay.const(3), relay.const(4)], attr1, [tt1])
- assert not alpha_equal(call, more_args)
+ assert not sequal(call, more_args)
different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
- assert not alpha_equal(call, different_attrs)
+ assert not sequal(call, different_attrs)
same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
- assert alpha_equal(call, same_attrs)
+ assert sequal(call, same_attrs)
no_type_args = relay.Call(v1, basic_args, attr1)
- assert not alpha_equal(call, no_type_args)
+ assert not sequal(call, no_type_args)
more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
- assert not alpha_equal(call, more_type_args)
+ assert not sequal(call, more_type_args)
different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
- assert not alpha_equal(call, different_type_arg)
+ assert not sequal(call, different_type_arg)
-def test_let_alpha_equal():
+def test_let_sequal():
tt1 = relay.TensorType((), "float32")
tt2 = relay.TensorType((), "int8")
v1 = relay.Var("v1")
@@ -504,57 +505,57 @@ def test_let_alpha_equal():
let = relay.Let(v1, relay.const(2), v1)
mapped = relay.Let(v2, relay.const(2), v2)
- assert alpha_equal(let, mapped)
+ assert sequal(let, mapped)
mismatched_var = relay.Let(v2, relay.const(2), v3)
- assert not alpha_equal(let, mismatched_var)
+ assert not sequal(let, mismatched_var)
different_value = relay.Let(v2, relay.const(3), v2)
- assert not alpha_equal(let, different_value)
+ assert not sequal(let, different_value)
different_body = relay.Let(v2, relay.const(3), relay.const(12))
- assert not alpha_equal(let, different_body)
+ assert not sequal(let, different_body)
# specified types must match
let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
- assert alpha_equal(let_with_type, same_type)
- assert not alpha_equal(let, let_with_type)
+ assert sequal(let_with_type, same_type)
+ assert not sequal(let, let_with_type)
v2 = relay.Var("v1", tt2)
different_type = relay.Let(v2, relay.const(2), v2)
- assert not alpha_equal(let_with_type, different_type)
+ assert not sequal(let_with_type, different_type)
-def test_if_alpha_equal():
+def test_if_sequal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2),
relay.const(3)]))
same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2),
relay.const(3)]))
- assert alpha_equal(if_sample, same)
+ assert sequal(if_sample, same)
different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2),
relay.const(3)]))
- assert not alpha_equal(if_sample, different_cond)
+ assert not sequal(if_sample, different_cond)
different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2),
relay.const(3)]))
- assert not alpha_equal(if_sample, different_true)
+ assert not sequal(if_sample, different_true)
different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
- assert not alpha_equal(if_sample, different_false)
+ assert not sequal(if_sample, different_false)
-def test_constructor_alpha_equal():
+def test_constructor_sequal():
# smoke test: it should be pointer equality
mod = tvm.IRModule()
p = relay.prelude.Prelude(mod)
- assert alpha_equal(p.nil, p.nil)
- assert alpha_equal(p.cons, p.cons)
- assert not alpha_equal(p.nil, p.cons)
+ assert sequal(p.nil, p.nil)
+ assert sequal(p.cons, p.cons)
+ assert not sequal(p.nil, p.cons)
-def test_match_alpha_equal():
+def test_match_sequal():
mod = tvm.IRModule()
p = relay.prelude.Prelude(mod)
@@ -604,27 +605,28 @@ def test_match_alpha_equal():
p.cons(x, p.nil()))
])
- assert alpha_equal(match, match)
- assert alpha_equal(match, equivalent)
- assert not alpha_equal(match, no_cons)
- assert not alpha_equal(match, no_nil)
- assert not alpha_equal(match, empty)
- assert not alpha_equal(match, different_data)
- assert not alpha_equal(match, different_order)
- assert not alpha_equal(match, different_nil)
- assert not alpha_equal(match, different_cons)
- assert not alpha_equal(match, another_case)
- assert not alpha_equal(match, wrong_constructors)
-
-
-def test_op_alpha_equal():
+ tvm.ir.assert_structural_equal(match, match)
+ assert sequal(match, match)
+ assert sequal(match, equivalent)
+ assert not sequal(match, no_cons)
+ assert not sequal(match, no_nil)
+ assert not sequal(match, empty)
+ assert not sequal(match, different_data)
+ assert not sequal(match, different_order)
+ assert not sequal(match, different_nil)
+ assert not sequal(match, different_cons)
+ assert not sequal(match, another_case)
+ assert not sequal(match, wrong_constructors)
+
+
+def test_op_sequal():
# only checks names
op1 = relay.op.get("add")
op2 = relay.op.get("add")
- assert alpha_equal(op1, op2)
+ assert sequal(op1, op2)
op3 = relay.op.get("take")
- assert not alpha_equal(op1, op3)
+ assert not sequal(op1, op3)
def test_graph_equal():
@@ -638,14 +640,14 @@ def test_graph_equal():
z3 = relay.add(relay.add(x, x), relay.add(x, x))
- assert alpha_equal(z0, z1)
- assert alpha_equal(z0, z1)
+ assert sequal(z0, z1)
+ assert sequal(z0, z1)
# z3's dataflow format is different from z0
# z0 is computed from a common y0 node
# Relay view them as different programs
# Check the difference in the text format.
- assert not alpha_equal(z0, z3)
+ assert not sequal(z0, z3)
def test_hash_unequal():
x1 = relay.var("x1", shape=(10, 10), dtype="float32")
@@ -677,7 +679,7 @@ def test_tuple_match():
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a),
relay.PatternVar(b)]), a + b)
y = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
- assert analysis.alpha_equal(x, y)
+ assert sequal(x, y)
assert analysis.structural_hash(x) == analysis.structural_hash(y)
@@ -697,34 +699,34 @@ def test_fn_attribute():
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
- assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
- assert not relay.analysis.alpha_equal(add_fn, add_1_fn)
+ assert not sequal(add_1_fn, add_fn)
+ assert not sequal(add_fn, add_1_fn)
if __name__ == "__main__":
- test_tensor_type_alpha_equal()
- test_incomplete_type_alpha_equal()
- test_constant_alpha_equal()
- test_type_node_alpha_equal()
- test_type_node_incompatible_alpha_equal()
- test_expr_node_incompatible_alpha_equal()
- test_func_type_alpha_equal()
- test_tuple_type_alpha_equal()
- test_type_relation_alpha_equal()
- test_type_call_alpha_equal()
- test_constant_alpha_equal()
- test_global_var_alpha_equal()
- test_tuple_alpha_equal()
- test_tuple_get_item_alpha_equal()
- test_function_alpha_equal()
+ test_tensor_type_sequal()
+ test_incomplete_type_sequal()
+ test_constant_sequal()
+ test_type_node_sequal()
+ test_type_node_incompatible_sequal()
+ test_expr_node_incompatible_sequal()
+ test_func_type_sequal()
+ test_tuple_type_sequal()
+ test_type_relation_sequal()
+ test_type_call_sequal()
+ test_constant_sequal()
+ test_global_var_sequal()
+ test_tuple_sequal()
+ test_tuple_get_item_sequal()
+ test_function_sequal()
test_function_attr()
- test_call_alpha_equal()
- test_let_alpha_equal()
- test_if_alpha_equal()
- test_constructor_alpha_equal()
- test_match_alpha_equal()
- test_op_alpha_equal()
- test_var_alpha_equal()
+ test_call_sequal()
+ test_let_sequal()
+ test_if_sequal()
+ test_constructor_sequal()
+ test_match_sequal()
+ test_op_sequal()
+ test_var_sequal()
test_graph_equal()
test_hash_unequal()
test_fn_attribute()
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py
b/tests/python/relay/test_pass_dead_code_elimination.py
index 604ec89..3a0bf1f 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -57,14 +57,14 @@ def run_opt_pass(expr, opt_pass):
def test_let():
orig = relay.Let(e.x, e.y, e.z)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
- assert alpha_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
+ assert tvm.ir.structural_equal(Function(free_vars(orig), orig),
Function([e.z], e.z))
def test_used_let():
orig = relay.Let(e.c, e.one, e.c + e.c)
orig = run_opt_pass(orig, transform.DeadCodeElimination())
expected = relay.Let(e.c, e.one, e.c + e.c)
- assert alpha_equal(Function([e.c], orig), Function([e.c], expected))
+ assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
@@ -75,7 +75,7 @@ def test_inline():
def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
orig = run_opt_pass(orig, transform.DeadCodeElimination())
- assert alpha_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
+ assert tvm.ir.structural_equal(Function(free_vars(orig), orig),
Function([e.e], e.e))
def use_f(func):
@@ -111,13 +111,13 @@ def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
dced_f = lambda f: x
dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
- assert alpha_equal(dced, e.three)
+ assert tvm.ir.structural_equal(dced, e.three)
def test_op_let():
dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two),
transform.DeadCodeElimination())
- assert alpha_equal(dced, add(e.three, e.two))
+ assert tvm.ir.structural_equal(dced, add(e.three, e.two))
def test_tuple_get_item():
@@ -126,10 +126,10 @@ def test_tuple_get_item():
a = relay.Var('a')
g = relay.TupleGetItem(t, 0)
dced = run_opt_pass(g, transform.DeadCodeElimination())
- assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g),
g))
+ assert tvm.ir.structural_equal(Function(free_vars(dced), dced),
Function(free_vars(g), g))
orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
dced = run_opt_pass(orig, transform.DeadCodeElimination())
- assert alpha_equal(Function(free_vars(dced), dced), Function(free_vars(g),
g))
+ assert tvm.ir.structural_equal(Function(free_vars(dced), dced),
Function(free_vars(g), g))
@pytest.mark.timeout(timeout=10, method="thread")
diff --git a/tests/python/relay/test_pass_partial_eval.py
b/tests/python/relay/test_pass_partial_eval.py
index f54dd6b..1299084 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -72,7 +72,7 @@ def test_tuple():
f = Function([x], body, None, [t])
expected = relay.Function([x], x, None, [t])
expected = run_opt_pass(expected, transform.InferType())
- assert alpha_equal(dcpe(f), expected)
+ assert tvm.ir.structural_equal(dcpe(f), expected)
def test_const_inline():
@@ -80,7 +80,7 @@ def test_const_inline():
d = Var("d", t)
double = Function([d], d + d)
orig = double(const(4.0))
- assert alpha_equal(dcpe(orig), const(8.0))
+ assert tvm.ir.structural_equal(dcpe(orig), const(8.0))
def test_ref():
@@ -93,7 +93,7 @@ def test_ref():
body = Let(r, RefCreate(d), body)
square = Function([d], body)
expected = run_opt_pass(Function([d], d * d), transform.InferType())
- assert alpha_equal(dcpe(square), expected)
+ assert tvm.ir.structural_equal(dcpe(square), expected)
def test_empty_ad():
@@ -105,7 +105,7 @@ def test_empty_ad():
g = dcpe(f, grad=True)
expected = Function([d], Tuple([d, Tuple([op.ones_like(d)])]))
expected = run_opt_pass(expected, transform.InferType())
- assert alpha_equal(g, expected)
+ assert tvm.ir.structural_equal(g, expected)
def test_ad():
@@ -180,7 +180,7 @@ def test_head_cons():
body = hd(p.cons(x, p.nil()))
f = Function([x], body, None, [t])
res = dcpe(f, mod)
- assert alpha_equal(res, Function([x], x, t, [t]))
+ assert tvm.ir.structural_equal(res, Function([x], x, t, [t]))
def test_map():
@@ -197,7 +197,7 @@ def test_map():
expected = mod["main"]
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, expected.body)
+ assert tvm.ir.structural_equal(res.body, expected.body)
def test_loop():
@@ -211,7 +211,7 @@ def test_loop():
expected = mod["main"].body
call = Function([], loop(const(1)))
res = dcpe(call, mod=mod)
- assert alpha_equal(res.body, expected)
+ assert tvm.ir.structural_equal(res.body, expected)
def test_swap_loop():
@@ -226,7 +226,7 @@ def test_swap_loop():
prog = loop(make_nat_expr(p, 1), make_nat_expr(p, 2))
res = Function([], prog)
res = dcpe(res, mod=mod)
- assert alpha_equal(prog, res.body)
+ assert tvm.ir.structural_equal(prog, res.body)
def test_abs_diff():
@@ -248,7 +248,7 @@ def test_abs_diff():
orig = diff(make_nat_expr(p, 7), make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, make_nat_expr(p, 4))
+ assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 4))
def test_match_nat_id():
@@ -265,7 +265,7 @@ def test_match_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, make_nat_expr(p, 3))
+ assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_nat_id():
@@ -280,7 +280,7 @@ def test_nat_id():
orig = nat_id(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, make_nat_expr(p, 3))
+ assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_global_match_nat_id():
@@ -294,7 +294,7 @@ def test_global_match_nat_id():
orig = Match(make_nat_expr(p, 3), [z_case, s_case])
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, make_nat_expr(p, 3))
+ assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 3))
def test_double():
@@ -304,7 +304,7 @@ def test_double():
orig = p.double(make_nat_expr(p, 3))
orig = Function([], orig)
res = dcpe(orig, mod=mod)
- assert alpha_equal(res.body, make_nat_expr(p, 6))
+ assert tvm.ir.structural_equal(res.body, make_nat_expr(p, 6))
def test_concat():
diff --git a/tests/python/relay/test_pass_qnn_legalize.py
b/tests/python/relay/test_pass_qnn_legalize.py
index ed05096..b164821 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -134,7 +134,7 @@ def test_qnn_legalize_qnn_conv2d():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert alpha_equal(mod, legalized_mod)
+ assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
@@ -157,7 +157,7 @@ def test_qnn_legalize_qnn_conv2d():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert alpha_equal(mod, legalized_mod)
+ assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod'):
@@ -221,7 +221,7 @@ def test_qnn_legalize_qnn_dense():
# Since same dtype, there should not be any transformation
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert alpha_equal(mod, legalized_mod)
+ assert tvm.ir.structural_equal(mod, legalized_mod)
################################################################
# Check transformations for platforms without fast Int8 support.
@@ -244,7 +244,7 @@ def test_qnn_legalize_qnn_dense():
# Check no transformation for Intel VNNI.
with tvm.target.create('llvm -mcpu=skylake-avx512'):
legalized_mod = relay.qnn.transform.Legalize()(mod)
- assert alpha_equal(mod, legalized_mod)
+ assert tvm.ir.structural_equal(mod, legalized_mod)
# ARM - so check that transformation has happened.
with tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu
-mattr=+v8.2a,+dotprod'):
diff --git a/tests/python/relay/test_pass_to_a_normal_form.py
b/tests/python/relay/test_pass_to_a_normal_form.py
index 2a6103e..29818f8 100644
--- a/tests/python/relay/test_pass_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -76,7 +76,7 @@ def test_order():
expected_output = relay.Let(b, y, expected_output)
expected_output = relay.Let(a, x, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
- assert alpha_equal(anf, expected_output)
+ assert tvm.ir.structural_equal(anf, expected_output)
def test_if():
@@ -93,7 +93,7 @@ def test_if():
expected_output = relay.Let(d, expected_output, d)
expected_output = relay.Let(c, cond, expected_output)
expected_output = run_opt_pass(expected_output, transform.InferType())
- assert alpha_equal(anf, expected_output)
+ assert tvm.ir.structural_equal(anf, expected_output)
# make sure we dont infinite loop.
diff --git a/tests/python/relay/test_pass_to_cps.py
b/tests/python/relay/test_pass_to_cps.py
index e2ac924..4aaa9a0 100644
--- a/tests/python/relay/test_pass_to_cps.py
+++ b/tests/python/relay/test_pass_to_cps.py
@@ -17,7 +17,7 @@
import numpy as np
import tvm
from tvm import relay
-from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.analysis import detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.analysis import Feature
from tvm.relay.prelude import Prelude
diff --git a/tests/python/relay/test_type_infer.py
b/tests/python/relay/test_type_infer.py
index 74507ba..4591618 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -21,7 +21,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import op, transform, analysis
-from tvm.relay.analysis import assert_alpha_equal
def run_infer_type(expr, mod=None):
@@ -360,7 +359,7 @@ def test_let_polymorphism():
body = relay.Let(id, relay.Function([x], x, xt, [xt]), body)
body = run_infer_type(body)
int32 = relay.TensorType((), "int32")
- assert_alpha_equal(body.checked_type, relay.TupleType([int32,
relay.TupleType([])]))
+ tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32,
relay.TupleType([])]))
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_node_reflection.py
b/tests/python/unittest/test_node_reflection.py
index a25ba0a..f2848ff 100644
--- a/tests/python/unittest/test_node_reflection.py
+++ b/tests/python/unittest/test_node_reflection.py
@@ -25,7 +25,7 @@ def test_const_saveload_json():
z = z + z
json_str = tvm.ir.save_json(z)
zz = tvm.ir.load_json(json_str)
- assert tvm.ir.save_json(zz) == tvm.ir.save_json(z)
+ tvm.ir.assert_structural_equal(zz, z, map_free_vars=True)
def test_make_smap():
@@ -38,6 +38,7 @@ def test_make_smap():
arr = tvm.ir.load_json(json_str)
assert len(arr) == 1
assert arr[0]["z"].a == arr[0]["x"]
+ tvm.ir.assert_structural_equal(arr, [smap], map_free_vars=True)
def test_make_node():
@@ -90,7 +91,6 @@ def test_env_func():
if __name__ == "__main__":
test_env_func()
- test_make_attrs()
test_make_node()
test_make_smap()
test_const_saveload_json()
diff --git a/tests/python/unittest/test_tir_structural_equal.py
b/tests/python/unittest/test_tir_structural_equal.py
new file mode 100644
index 0000000..26f3085
--- /dev/null
+++ b/tests/python/unittest/test_tir_structural_equal.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+from tvm import te
+
+
+def test_exprs():
+ """Structural equality of TIR expressions, incl. free/bound var remapping."""
+ # constants and free variables shared by the checks below
+ x = tvm.tir.const(1, "int32")
+ y = tvm.tir.const(10, "int32")
+ vx = te.var("x")
+ vy = te.var("y")
+ vz = te.var("z")
+
+ # assert_structural_equal raises ValueError on a mismatch.
+ with pytest.raises(ValueError):
+ tvm.ir.assert_structural_equal(x, y)
+
+ assert not tvm.ir.structural_equal(vx, vy)
+ assert tvm.ir.structural_equal(vx, vy, map_free_vars=True)
+ # corner case: once lhs:vx is mapped to rhs:vy it cannot also map to rhs:vx
+ assert not tvm.ir.structural_equal(vx + vx, vy + vx, map_free_vars=True)
+ # corner case: swapped mapping lhs:vx == rhs:vy, lhs:vy == rhs:vx
+ assert tvm.ir.structural_equal(vx + vy, vy + vx, map_free_vars=True)
+ # corner case 2: rolling remap (vx -> vy, vy -> vz, vz -> vx).
+ assert tvm.ir.structural_equal(vx + vy + vz, vy + vz + vx,
+ map_free_vars=True)
+ # without map_free_vars, distinct free vars are never remapped
+ assert not tvm.ir.structural_equal(vx + 1, vy + 1, map_free_vars=False)
+ # Definition remap: vars bound by Let are mapped even with default flags
+ assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx - 1),
+ tvm.tir.Let(vy, 1, vy - 1))
+ # a free var (vz) that is the same object on both sides compares equal
+ assert tvm.ir.structural_equal(tvm.tir.Let(vx, 1, vx // vz),
+ tvm.tir.Let(vy, 1, vy // vz))
+
+ # reused subexpression nodes (zx * zx) vs. freshly built equal subtrees
+ zx = vx + vx
+ zy = vy + vy
+ assert tvm.ir.structural_equal(zx * zx, zx * zx)
+ assert tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=True)
+ assert not tvm.ir.structural_equal(zx * zx, zy * zy, map_free_vars=False)
+ assert tvm.ir.structural_equal(zx * zx, (vx + vx) * (vx + vx),
+ map_free_vars=False)
+
+
+def test_prim_func():
+ """Structural equality of PrimFuncs, their attributes and IRModules."""
+ x = te.var('x')
+ y = te.var('y')
+ # counterexample for a same_as (pointer) shortcut: both funcs share the
+ # very same param vars x, y, yet their bodies differ (x + y vs y + x)
+ func0 = tvm.tir.PrimFunc(
+ [x, y], tvm.tir.Evaluate(x + y))
+ func1 = tvm.tir.PrimFunc(
+ [x, y], tvm.tir.Evaluate(y + x))
+ assert not tvm.ir.structural_equal(func0, func1)
+
+ # PrimFunc with a buffer param whose shape uses x, and a LetStmt body
+ # that rebinds the param x
+ b = tvm.tir.decl_buffer((x,), "float32")
+ stmt = tvm.tir.LetStmt(
+ x, 10, tvm.tir.Evaluate(x + 1))
+ func0 = tvm.tir.PrimFunc(
+ [x, y, b], stmt)
+ # easiest way to deep copy is via save/load
+ func1 = tvm.ir.load_json(tvm.ir.save_json(func0))
+ tvm.ir.assert_structural_equal(func0, func1)
+
+ data0 = tvm.nd.array([1, 2, 3])
+ data1 = tvm.nd.array([1, 2, 3])
+ # attach distinct NDArray objects with equal contents as function attrs
+ func0 = func0.with_attr("data", data0)
+ func1 = func1.with_attr("data", data1)
+ # wrapping the functions into IRModules must preserve equality
+ mod0 = tvm.IRModule.from_expr(func0)
+ mod1 = tvm.IRModule.from_expr(func1)
+ tvm.ir.assert_structural_equal(mod0, mod1)
+
+
+def test_attrs():
+ # attrs nodes with identical field values are equal; differing ones are not
+ x = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
+ y = tvm.ir.make_node("attrs.TestAttrs", axis=1, name="xx")
+ z = tvm.ir.make_node("attrs.TestAttrs", axis=2, name="xx")
+ tvm.ir.assert_structural_equal(y, x)
+ assert not tvm.ir.structural_equal(y, z)
+
+
+if __name__ == "__main__":
+ # run the tests directly, without the pytest runner
+ test_exprs()
+ test_prim_func()
+ test_attrs()