This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 919ae88 [REFACTOR][IR] alpha_equal to structural_equal (#5161)
919ae88 is described below
commit 919ae889638555b82de2d124d5f3e08d76bf789b
Author: Zhi <[email protected]>
AuthorDate: Sun Mar 29 09:58:58 2020 -0700
[REFACTOR][IR] alpha_equal to structural_equal (#5161)
---
include/tvm/ir/type.h | 4 +-
include/tvm/relay/analysis.h | 55 --
python/tvm/ir/type.py | 3 +-
python/tvm/relay/__init__.py | 1 -
python/tvm/relay/analysis/analysis.py | 72 ---
src/ir/module.cc | 19 +-
src/relay/analysis/alpha_equal.cc | 628 ---------------------
src/relay/analysis/type_solver.cc | 7 +-
src/relay/backend/compile_engine.h | 3 +-
src/relay/backend/vm/lambda_lift.cc | 4 +-
src/relay/op/tensor/transform.cc | 7 +-
src/relay/transforms/lazy_gradient_init.cc | 22 +-
src/relay/transforms/pattern_util.h | 3 +-
src/relay/transforms/to_cps.cc | 2 +-
tests/cpp/relay_pass_alpha_equal.cc | 67 ---
tests/cpp/relay_pass_type_infer_test.cc | 3 +-
tests/cpp/relay_transform_sequential.cc | 3 +-
tests/python/frontend/caffe2/test_graph.py | 3 +-
tests/python/frontend/mxnet/test_graph.py | 2 +-
.../relay/test_analysis_extract_fused_functions.py | 2 +-
tests/python/relay/test_annotate_target.py | 2 +-
tests/python/relay/test_call_graph.py | 2 +-
tests/python/relay/test_ir_bind.py | 4 +-
tests/python/relay/test_ir_nodes.py | 3 +-
.../python/relay/test_ir_structural_equal_hash.py | 2 +-
tests/python/relay/test_ir_text_printer.py | 6 +-
tests/python/relay/test_op_level10.py | 4 +-
tests/python/relay/test_pass_alter_op_layout.py | 46 +-
tests/python/relay/test_pass_annotation.py | 12 +-
tests/python/relay/test_pass_canonicalize_cast.py | 2 +-
.../relay/test_pass_combine_parallel_conv2d.py | 8 +-
.../relay/test_pass_combine_parallel_dense.py | 6 +-
tests/python/relay/test_pass_convert_op_layout.py | 22 +-
.../relay/test_pass_dead_code_elimination.py | 6 +-
.../relay/test_pass_eliminate_common_subexpr.py | 4 +-
tests/python/relay/test_pass_eta_expand.py | 6 +-
tests/python/relay/test_pass_fold_constant.py | 14 +-
tests/python/relay/test_pass_fold_scale_axis.py | 22 +-
tests/python/relay/test_pass_fuse_ops.py | 30 +-
tests/python/relay/test_pass_gradient.py | 4 +-
tests/python/relay/test_pass_inline.py | 32 +-
tests/python/relay/test_pass_legalize.py | 8 +-
tests/python/relay/test_pass_manager.py | 6 +-
tests/python/relay/test_pass_merge_composite.py | 23 +-
tests/python/relay/test_pass_partial_eval.py | 9 +-
tests/python/relay/test_pass_partition_graph.py | 10 +-
tests/python/relay/test_pass_qnn_legalize.py | 6 +-
.../relay/test_pass_remove_unused_functions.py | 2 +-
tests/python/relay/test_pass_simplify_inference.py | 4 +-
tests/python/relay/test_pass_to_a_normal_form.py | 2 +-
tests/python/relay/test_type_functor.py | 4 +-
tests/python/unittest/test_ir_type.py | 3 +-
52 files changed, 208 insertions(+), 1016 deletions(-)
diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 6f6c66a..0e65758 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
}
bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
- return equal(kind, other->kind);
+ return
+ equal(kind, other->kind) &&
+ equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index fe8fae5..51eae5a 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -65,61 +65,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
TVM_DLL bool ConstantCheck(const Expr& e);
/*!
- * \brief Compare two expressions for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param e1 The left hand expression.
- * \param e2 The right hand expression.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
-
-/*!
- * \brief Compare two types for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `forall s, Tensor[f32, s]` is equal to
- * `forall w, Tensor[f32, w]`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand type.
- * \param t2 The right hand type.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
-
-/*!
- * \brief Compare two patterns for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * patterns without regard to variable choice.
- *
- * For example: `A(x, _, y)` is equal to `A(z, _, a)`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand pattern.
- * \param t2 The right hand pattern.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
-
-/*!
* \brief Check that each Var is only bound once.
*
* For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index a61c6e4..e980011 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -16,6 +16,7 @@
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
+import tvm
import tvm._ffi
from .base import Node
@@ -26,7 +27,7 @@ class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
- return bool(_ffi_api.type_alpha_equal(self, other))
+ return bool(tvm.ir.structural_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 95545c8..1517cf9 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -33,7 +33,6 @@ from . import parser
from . import transform
from . import analysis
-from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import debug
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 722f3b0..b09a40b 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)
-def alpha_equal(lhs, rhs):
- """Compare two Relay expr for structural equivalence (alpha equivalence).
-
- Parameters
- ----------
- lhs : tvm.relay.Expr
- One of the input Expression.
-
- rhs : tvm.relay.Expr
- One of the input Expression.
-
- Returns
- -------
- result : bool
- True iff lhs is alpha equal to rhs.
- """
- return bool(_ffi_api._alpha_equal(lhs, rhs))
-
-
-def assert_alpha_equal(lhs, rhs):
- """Assert that two Relay expr is structurally equivalent. (alpha
equivalence).
-
- Parameters
- ----------
- lhs : tvm.relay.Expr
- One of the input Expression.
-
- rhs : tvm.relay.Expr
- One of the input Expression.
- """
- _ffi_api._assert_alpha_equal(lhs, rhs)
-
-
-def graph_equal(lhs, rhs):
- """Compare two Relay expr for data-flow equivalence.
- The difference between this and alpha-equality is that
- variables are not expected to match between lhs and rhs;
- they are treated as sources and are mapped between each other.
-
- Parameters
- ----------
- lhs : tvm.relay.Expr
- One of the input Expression.
-
- rhs : tvm.relay.Expr
- One of the input Expression.
-
- Returns
- -------
- result : bool
- True iff lhs is data-flow equivalent to rhs.
- """
- return bool(_ffi_api._graph_equal(lhs, rhs))
-
-
-def assert_graph_equal(lhs, rhs):
- """Compare two Relay expr for data-flow equivalence.
- The difference between this and alpha-equality is that
- variables are not expected to match between lhs and rhs;
- they are treated as sources and are mapped between each other.
-
- Parameters
- ----------
- lhs : tvm.relay.Expr
- One of the input Expression.
-
- rhs : tvm.relay.Expr
- One of the input Expression.
- """
- _ffi_api._assert_graph_equal(lhs, rhs)
-
-
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
diff --git a/src/ir/module.cc b/src/ir/module.cc
index de09314..c7474de 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -23,6 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
@@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
<< AsText(func, false)
<< std::endl;
}
- func =
- relay::Function(concat(func->params, fv),
- func->body,
- func->ret_type,
- concat(func->type_params, ftv),
- func->attrs);
+ func = relay::Function(concat(func->params, fv),
+ func->body,
+ func->ret_type,
+ concat(func->type_params, ftv),
+ func->attrs);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
@@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var]->checked_type();
- CHECK(relay::AlphaEqual(type, old_type))
+ CHECK(tvm::StructuralEqual()(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
@@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
- func = relay::Function(
- relay::FreeVars(expr), expr, Type(),
- relay::FreeTypeVars(expr, mod), {});
+ func = relay::Function(relay::FreeVars(expr), expr, Type(),
+ relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc
deleted file mode 100644
index 28c7681..0000000
--- a/src/relay/analysis/alpha_equal.cc
+++ /dev/null
@@ -1,628 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/relay/analysis/alpha_equal.cc
- * \brief Alpha equality check by deep comparing two nodes.
- */
-#include <tvm/ir/type_functor.h>
-#include <tvm/tir/ir_pass.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/attrs/nn.h>
-#include "../../ir/attr_functor.h"
-
-
-namespace tvm {
-namespace relay {
-
-// Alpha Equal handler for Relay.
-class AlphaEqualHandler:
- public AttrsEqualHandler,
- public TypeFunctor<bool(const Type&, const Type&)>,
- public ExprFunctor<bool(const Expr&, const Expr&)>,
- public PatternFunctor<bool(const Pattern&, const Pattern&)> {
- public:
- explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
- : map_free_var_(map_free_var), assert_mode_(assert_mode) { }
-
- /*!
- * Check equality of two nodes.
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return The comparison result.
- */
- bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
- return VisitAttr(lhs, rhs);
- }
- /*!
- * Check equality of two attributes.
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return The comparison result.
- */
- bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
- auto compute = [&]() {
- return VisitAttr(lhs, rhs);
- };
- return Compare(compute(), lhs, rhs);
- }
- /*!
- * Check equality of two types.
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return the comparison result.
- */
- bool TypeEqual(const Type& lhs, const Type& rhs) {
- auto compute = [&]() {
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() || !rhs.defined()) return false;
- return this->VisitType(lhs, rhs);
- };
- return Compare(compute(), lhs, rhs);
- }
-
- bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
- if (assert_mode_) {
- CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
- }
- return result;
- }
- /*!
- * Check equality of two expressions.
- *
- * \note We run graph structural equality checking when comparing two Exprs.
- * This means that AlphaEqualHandler can only be used once for each pair.
- * The equality checker checks data-flow equvalence of the Expr DAG.
- * This function also runs faster as it memomizes equal_map.
- *
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return The comparison result.
- */
- bool ExprEqual(const Expr& lhs, const Expr& rhs) {
- auto compute = [&]() {
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() || !rhs.defined()) return false;
- auto it = equal_map_.find(lhs);
- if (it != equal_map_.end()) {
- return it->second.same_as(rhs);
- }
- if (this->VisitExpr(lhs, rhs)) {
- equal_map_[lhs] = rhs;
- return true;
- } else {
- return false;
- }
- };
- return Compare(compute(), lhs, rhs);
- }
-
- protected:
- // So that the new definition of equality in relay can be handled directly.
- // Specifically, if a DictAttr contains a value defined by a relay AST.
- // We want to able to recursively check the equality in the attr defined by
the relay AST.
- bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() && rhs.defined()) return false;
- if (!rhs.defined() && lhs.defined()) return false;
- if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
- if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return
false;
- return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
- }
- if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
- if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return
false;
- return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
- }
- if (const auto lhsm = lhs.as<IRModuleNode>()) {
- auto rhsm = rhs.as<IRModuleNode>();
- if (!rhsm) return false;
- if (lhsm->functions.size() != rhsm->functions.size()) return false;
- for (const auto& p : lhsm->functions) {
- if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
- }
- if (lhsm->type_definitions.size() != rhsm->type_definitions.size())
return false;
- for (const auto& p : lhsm->type_definitions) {
- if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
- !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
- return false;
- }
- }
- return true;
- }
- // Fall back to the object equal case.
- return AttrsEqualHandler::VisitAttr(lhs, rhs);
- }
- /*!
- * \brief Check if data type equals each other.
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return The compare result.
- */
- bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
- return lhs == rhs;
- }
- /*!
- * \brief Check Equality of leaf node of the graph.
- * if map_free_var_ is set to true, try to map via equal node.
- * \param lhs The left hand operand.
- * \param rhs The right hand operand.
- * \return The compare result.
- */
- bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
- if (lhs.same_as(rhs)) return true;
- auto it = equal_map_.find(lhs);
- if (it != equal_map_.end()) {
- return it->second.same_as(rhs);
- } else {
- if (map_free_var_) {
- if (lhs->type_index() != rhs->type_index()) return false;
- equal_map_[lhs] = rhs;
- return true;
- } else {
- return false;
- }
- }
- }
- using AttrsEqualHandler::VisitAttr_;
- bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final {
- return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
- }
-
- // Type equality
- bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
- if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
- return (lhs->dtype == rhs->dtype &&
- AttrEqual(lhs->shape, rhs->shape));
- } else {
- return false;
- }
- }
-
- bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
- return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
- }
-
- bool VisitType_(const PrimTypeNode* lhs, const Type& other) final {
- if (const PrimTypeNode* rhs = other.as<PrimTypeNode>()) {
- return lhs->dtype == rhs->dtype;
- } else {
- return false;
- }
- }
-
- bool VisitType_(const PointerTypeNode* lhs, const Type& other) final {
- if (const PointerTypeNode* rhs = other.as<PointerTypeNode>()) {
- return TypeEqual(lhs->element_type, rhs->element_type);
- } else {
- return false;
- }
- }
-
- bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
- if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
- if (lhs->kind != rhs->kind) return false;
- return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
- } else {
- return false;
- }
- }
-
- bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
- if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
- if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
- if (lhs->type_params.size() != rhs->type_params.size()) return false;
- if (lhs->type_constraints.size() != rhs->type_constraints.size()) return
false;
- for (size_t i = 0; i < lhs->type_params.size(); ++i) {
- if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
- return false;
- }
- equal_map_[lhs->type_params[i]] = rhs->type_params[i];
- }
- for (size_t i = 0; i < lhs->arg_types.size(); i++) {
- if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
- }
- if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
- for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
- if (!TypeEqual(lhs->type_constraints[i],
- rhs->type_constraints[i])) {
- return false;
- }
- }
- return true;
- } else {
- return false;
- }
- }
-
- bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
- if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
- if (lhs->func->name != rhs->func->name) return false;
- if (lhs->num_inputs != rhs->num_inputs) return false;
- if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
- if (lhs->args.size() != rhs->args.size()) return false;
- for (size_t i = 0; i < lhs->args.size(); ++i) {
- if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
- }
- return true;
- } else {
- return false;
- }
- }
-
- bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
- if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
- if (lhs->fields.size() != rhs->fields.size()) return false;
- for (size_t i = 0; i < lhs->fields.size(); ++i) {
- if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
- }
- return true;
- } else {
- return false;
- }
- }
-
- bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
- if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
- return TypeEqual(lhs->value, rhs->value);
- }
- return false;
- }
-
- bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
- return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
- }
-
- bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
- const TypeCallNode* rhs = other.as<TypeCallNode>();
- if (rhs == nullptr
- || lhs->args.size() != rhs->args.size()
- || !TypeEqual(lhs->func, rhs->func)) {
- return false;
- }
-
- for (size_t i = 0; i < lhs->args.size(); ++i) {
- if (!TypeEqual(lhs->args[i], rhs->args[i])) {
- return false;
- }
- }
- return true;
- }
-
- bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
- const TypeDataNode* rhs = other.as<TypeDataNode>();
- if (rhs == nullptr
- || lhs->type_vars.size() != rhs->type_vars.size()
- || !TypeEqual(lhs->header, rhs->header)) {
- return false;
- }
- for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
- if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
- return false;
- }
- }
- for (size_t i = 0; i < lhs->constructors.size(); ++i) {
- if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
- return false;
- }
- }
- return true;
- }
-
- // Expr equal checking.
- bool NDArrayEqual(const runtime::NDArray& lhs,
- const runtime::NDArray& rhs) {
- if (lhs.defined() != rhs.defined()) {
- return false;
- } else if (lhs.same_as(rhs)) {
- return true;
- } else {
- auto ldt = lhs->dtype;
- auto rdt = rhs->dtype;
- CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
- CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
- if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits ==
rdt.bits) {
- size_t data_size = runtime::GetDataSize(*lhs.operator->());
- return std::memcmp(lhs->data, rhs->data, data_size) == 0;
- } else {
- return false;
- }
- }
- }
- // merge declaration of two variables together.
- bool MergeVarDecl(const Var& lhs, const Var& rhs) {
- if (lhs.same_as(rhs)) return true;
- if (!lhs.defined() || !rhs.defined()) return false;
- if (!TypeEqual(lhs->type_annotation,
- rhs->type_annotation)) return false;
- CHECK(!equal_map_.count(lhs))
- << "Duplicated declaration of variable " << lhs;
- equal_map_[lhs] = rhs;
- return true;
- }
-
- bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
- // This function will only be triggered if we are matching free variables.
- if (const VarNode* rhs = other.as<VarNode>()) {
- if (lhs->name_hint() != rhs->name_hint()) return false;
- if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
- return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
- if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
- // use name equality for global var for now.
- return lhs->name_hint == rhs->name_hint;
- }
- return false;
- }
-
- bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
- if (const TupleNode* rhs = other.as<TupleNode>()) {
- if (lhs->fields.size() != rhs->fields.size()) return false;
- for (size_t i = 0; i < lhs->fields.size(); ++i) {
- if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
- }
- return true;
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
- if (const FunctionNode* rhs = other.as<FunctionNode>()) {
- if (lhs->params.size() != rhs->params.size()) return false;
- if (lhs->type_params.size() != rhs->type_params.size()) return false;
- // map type parameter to be the same
- for (size_t i = 0; i < lhs->type_params.size(); ++i) {
- if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return
false;
- equal_map_[lhs->type_params[i]] = rhs->type_params[i];
- }
- // check parameter type annotations
- for (size_t i = 0; i < lhs->params.size(); ++i) {
- if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
- }
- // check return types.
- if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
- if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
- return ExprEqual(lhs->body, rhs->body);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
- if (const CallNode* rhs = other.as<CallNode>()) {
- if (!ExprEqual(lhs->op, rhs->op)) return false;
- if (lhs->args.size() != rhs->args.size()) return false;
- // skip type_args check for primitive ops.
- bool is_primitive = IsPrimitiveOp(lhs->op);
- if (!is_primitive) {
- if (lhs->type_args.size() != rhs->type_args.size()) {
- return false;
- }
- }
- for (size_t i = 0; i < lhs->args.size(); ++i) {
- if (!ExprEqual(lhs->args[i], rhs->args[i])) {
- return false;
- }
- }
-
- if (!is_primitive) {
- for (size_t i = 0; i < lhs->type_args.size(); ++i) {
- if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
- }
- }
- return AttrEqual(lhs->attrs, rhs->attrs);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
- if (const LetNode* rhs = other.as<LetNode>()) {
- if (!MergeVarDecl(lhs->var, rhs->var)) return false;
- if (!ExprEqual(lhs->value, rhs->value)) return false;
- return ExprEqual(lhs->body, rhs->body);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
- if (const IfNode* rhs = other.as<IfNode>()) {
- return ExprEqual(lhs->cond, rhs->cond) &&
- ExprEqual(lhs->true_branch, rhs->true_branch) &&
- ExprEqual(lhs->false_branch, rhs->false_branch);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
- return lhs == other.get();
- }
-
- bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
- if (const ConstantNode* rhs = other.as<ConstantNode>()) {
- return NDArrayEqual(lhs->data, rhs->data);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
- if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
- return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
- if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
- return ExprEqual(lhs->value, rhs->value);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
- if (const RefReadNode* rhs = other.as<RefReadNode>()) {
- return ExprEqual(lhs->ref, rhs->ref);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
- if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
- return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value,
rhs->value);
- } else {
- return false;
- }
- }
-
- bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
- if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
- return lhs->name_hint == rhs->name_hint;
- }
- return false;
- }
-
- bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
- return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
- }
-
- bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
- return Compare(VisitPattern(lhs, rhs), lhs, rhs);
- }
-
- bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other)
final {
- return other.as<PatternWildcardNode>();
- }
-
- bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
- if (const auto* rhs = other.as<PatternVarNode>()) {
- return MergeVarDecl(lhs->var, rhs->var);
- }
- return false;
- }
-
- bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other)
final {
- const auto* rhs = other.as<PatternConstructorNode>();
- if (rhs == nullptr
- || !ExprEqual(lhs->constructor, rhs->constructor)
- || lhs->patterns.size() != rhs->patterns.size()) {
- return false;
- }
-
- for (size_t i = 0; i < lhs->patterns.size(); i++) {
- if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
- return false;
- }
- }
- return true;
- }
-
- bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
- const auto* rhs = other.as<PatternTupleNode>();
- if (rhs == nullptr
- || lhs->patterns.size() != rhs->patterns.size()) {
- return false;
- }
-
- for (size_t i = 0; i < lhs->patterns.size(); i++) {
- if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
- return false;
- }
- }
- return true;
- }
-
- bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
- const MatchNode* rhs = other.as<MatchNode>();
-
- if (rhs == nullptr
- || !ExprEqual(lhs->data, rhs->data)
- || lhs->clauses.size() != rhs->clauses.size()
- || lhs->complete != rhs->complete) {
- return false;
- }
-
- for (size_t i = 0; i < lhs->clauses.size(); ++i) {
- if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
- return false;
- }
- }
- return true;
- }
-
- private:
- // whether to map open terms.
- bool map_free_var_;
- // if in assert mode, must return true, and will throw error otherwise.
- bool assert_mode_;
- // renaming of NodeRef to indicate two nodes equals to each other
- std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
-};
-
-bool AlphaEqual(const Type& lhs, const Type& rhs) {
- return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
-}
-
-bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
- return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
-}
-
-TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
- return AlphaEqualHandler(false, false).Equal(a, b);
-});
-
-TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
-.set_body_typed([](Type a, Type b) {
- return AlphaEqual(a, b);
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
- bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
- CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._graph_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
- return AlphaEqualHandler(true, false).Equal(a, b);
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
- bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
- CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
-});
-
-} // namespace relay
-} // namespace tvm
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index c39df9d..650403c 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -21,6 +21,7 @@
* \file type_solver.cc
* \brief Type solver implementations.
*/
+#include <tvm/node/structural_equal.h>
#include <tvm/ir/type_functor.h>
#include <tvm/tir/op.h>
#include <string>
@@ -151,11 +152,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return rc.Check(t);
}
- // default: unify only if alpha-equal
+ // default: unify only if structural-equal
Type VisitTypeDefault_(const Object* op, const Type& tn) final {
ObjectRef nr = GetRef<ObjectRef>(op);
Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
- if (!AlphaEqual(t1, tn)) {
+ if (!tvm::StructuralEqual()(t1, tn)) {
return Type(nullptr);
}
return t1;
@@ -216,7 +217,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
auto tt1 = GetRef<TensorType>(op);
auto tt2 = GetRef<TensorType>(tt_node);
- if (AlphaEqual(tt1, tt2)) {
+ if (tvm::StructuralEqual()(tt1, tt2)) {
return std::move(tt1);
}
diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h
index 098211e..eec2bd3 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -25,6 +25,7 @@
#ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
+#include <tvm/node/structural_equal.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
@@ -268,7 +269,7 @@ inline bool CCacheKeyNode::Equal(
const CCacheKeyNode* other) const {
if (Hash() != other->Hash()) return false;
return this->target->str() == other->target->str() &&
- AlphaEqual(this->source_func, other->source_func);
+ tvm::StructuralEqual()(this->source_func, other->source_func);
}
} // namespace relay
diff --git a/src/relay/backend/vm/lambda_lift.cc
b/src/relay/backend/vm/lambda_lift.cc
index 7e7622c..398760f 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -22,6 +22,7 @@
* \brief Lift all local functions into global functions.
*/
+#include <tvm/node/structural_equal.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/support/logging.h>
@@ -161,7 +162,8 @@ class LambdaLifter : public ExprMutator {
if (module_->ContainGlobalVar(name)) {
const auto existing_func = module_->Lookup(name);
- CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash
collision";
+ CHECK(tvm::StructuralEqual()(lifted_func, existing_func))
+ << "lifted function hash collision";
// If an identical function already exists, use its global var.
global = module_->GetGlobalVar(name);
} else {
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3d03b4a..87b4602 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2142,7 +2142,12 @@ Expr MakeSplit(Expr data,
TVM_REGISTER_GLOBAL("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
if (args.type_codes[1] == kDLInt) {
- *rv = MakeSplit(args[0], tir::make_const(DataType::Int(64),
int64_t(args[1])), args[2]);
+ // Note: we change it from Int(64) to Int(32) for now as
+ // combine_parallel_dense will transform the graph with Int(32).
+ // More investigation is needed to check which one we should use.
+ *rv = MakeSplit(args[0],
+ tir::make_const(DataType::Int(32),
static_cast<int>(args[1])),
+ args[2]);
} else {
*rv = MakeSplit(args[0], args[1], args[2]);
}
diff --git a/src/relay/transforms/lazy_gradient_init.cc
b/src/relay/transforms/lazy_gradient_init.cc
index ba6ca05..e6248f1 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -59,6 +59,7 @@
* Thus, it is necessary to wrap this outer function so that the input/output
types remain the same
*/
+#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/ir/type_functor.h>
@@ -93,7 +94,7 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&,
const Type&)> {
Expr WrapExpr(const Expr expr, const Type& type) {
if (type.as<TensorTypeNode>()) {
return Call(module_->GetConstructor("GradCell", "Raw"),
- {expr}, Attrs(), {type});
+ {expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
for (size_t i = 0; i < type_anno->fields.size(); i++) {
@@ -185,7 +186,7 @@ class LazyGradientInitializer: public ExprMutator, public
TypeMutator {
Expr VisitExpr_(const ConstantNode* op) final {
return Call(module_->GetConstructor("GradCell", "Raw"),
- {GetRef<Constant>(op)}, Attrs(),
{op->checked_type()});
+ {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
}
Expr VisitExpr_(const CallNode* call_node) final {
@@ -207,26 +208,25 @@ class LazyGradientInitializer: public ExprMutator, public
TypeMutator {
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones") ? "One" :
"Zero";
return Call(module_->GetConstructor("GradCell", constructor_name),
- {func}, Attrs(), {call_node->checked_type()});
+ {func}, Attrs(), {call_node->checked_type()});
}
if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like"))
{
// ones_like and zeros_like need TensorType input
Expr result = CallPrimitiveOp(call_node);
// fn() -> T, function returns result of operation
- Expr func = Function({}, result,
- {call_node->checked_type()}, Array<TypeVar>());
+ Expr func = Function({}, result, {call_node->checked_type()},
Array<TypeVar>());
// call appropriate GradCell constructor
std::string constructor_name = op_expr == Op::Get("ones_like") ? "One"
: "Zero";
- return Call(module_->GetConstructor("GradCell", "One"),
- {func}, Attrs(), {call_node->checked_type()});
+ return Call(module_->GetConstructor("GradCell", constructor_name),
+ {func}, Attrs(), {call_node->checked_type()});
}
// handle all other ops
Expr result = CallPrimitiveOp(call_node);
// wrap result with Raw constructor
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ Attrs(), {call_node->checked_type()});
}
// not an op
return ExprMutator::VisitExpr_(call_node);
@@ -253,10 +253,11 @@ class LazyGradientInitializer: public ExprMutator, public
TypeMutator {
Expr CallGradCellFunction(const CallNode* call_node, GlobalVar
overloaded_op) {
// can only use overloaded functions if 2 arguments of same type
if (call_node->args.size() != 2 ||
- !AlphaEqual(call_node->args[0]->checked_type(),
call_node->args[1]->checked_type())) {
+ !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
+ call_node->args[1]->checked_type())) {
Expr result = CallPrimitiveOp(call_node);
return Call(module_->GetConstructor("GradCell", "Raw"), {result},
- Attrs(), {call_node->checked_type()});
+ Attrs(), {call_node->checked_type()});
}
tvm::Array<Expr> args;
@@ -266,8 +267,7 @@ class LazyGradientInitializer: public ExprMutator, public
TypeMutator {
Var("rhs", paramType)};
// use primitive op in this case
Expr callOp = Call(call_node->op, {params[0], params[1]});
- Expr func = Function(params, callOp, paramType,
- Array<TypeVar>());
+ Expr func = Function(params, callOp, paramType, Array<TypeVar>());
// pass "fallback" function and tensors as arguments
args.push_back(func);
diff --git a/src/relay/transforms/pattern_util.h
b/src/relay/transforms/pattern_util.h
index e86fcdc..8ce42a2 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -27,6 +27,7 @@
#define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
#include <builtin_fp16.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/tir/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
@@ -300,7 +301,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
if (!constant_a || !constant_b || !constant_a->is_scalar() ||
!constant_b->is_scalar()) {
return false;
}
- return AlphaEqual(a, b);
+ return tvm::StructuralEqual()(a, b);
}
inline Expr GetField(Expr t, size_t i) {
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 1039a1b..e6c8392 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -353,7 +353,7 @@ Function UnCPS(const Function& f) {
auto answer_type = new_type_params.back();
new_type_params.pop_back();
// TODO(@M.K.): make alphaequal work on free term
- // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
+ // CHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type,
answer_type)));
auto x = Var("x", new_ret_type);
auto cont = Function({x}, x, new_ret_type, {}, {});
tvm::Array<Expr> args;
diff --git a/tests/cpp/relay_pass_alpha_equal.cc
b/tests/cpp/relay_pass_alpha_equal.cc
deleted file mode 100644
index 0207fca..0000000
--- a/tests/cpp/relay_pass_alpha_equal.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <gtest/gtest.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-
-using namespace tvm;
-
-class TestAlphaEquals {
- runtime::PackedFunc *_packed_func;
- public:
- TestAlphaEquals(const char* func_name) {
- _packed_func = new runtime::PackedFunc();
- TVMFuncGetGlobal(func_name,
reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
- }
-
- void UpdatePackedFunc(const char* func_name) {
- TVMFuncGetGlobal(func_name,
reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
- }
-
- bool operator()(ObjectRef input_1, ObjectRef input_2) {
- TVMRetValue rv;
- std::vector<TVMValue> values(2);
- std::vector<int> codes(2);
- runtime::TVMArgsSetter setter(values.data(), codes.data());
- setter(0, input_1);
- setter(1, input_2);
- _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
- return bool(rv);
- };
-
-};
-
-TEST(Relay, AlphaTestEmptyTypeNodes) {
- auto x = TypeVar("x", kTypeData);
- auto y = TypeVar();
- EXPECT_FALSE(relay::AlphaEqual(x, y));
-
- TestAlphaEquals test_equals("relay._make._alpha_equal");
- EXPECT_FALSE(test_equals(x, y));
-}
-
-int main(int argc, char ** argv) {
- testing::InitGoogleTest(&argc, argv);
- testing::FLAGS_gtest_death_test_style = "threadsafe";
- return RUN_ALL_TESTS();
-}
diff --git a/tests/cpp/relay_pass_type_infer_test.cc
b/tests/cpp/relay_pass_type_infer_test.cc
index f951a8f..3c41691 100644
--- a/tests/cpp/relay_pass_type_infer_test.cc
+++ b/tests/cpp/relay_pass_type_infer_test.cc
@@ -18,6 +18,7 @@
*/
#include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
@@ -38,7 +39,7 @@ TEST(Relay, SelfReference) {
auto type_fx = mod->Lookup("main");
auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type },
tensor_type, {}, {});
- CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
+ CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
}
int main(int argc, char ** argv) {
diff --git a/tests/cpp/relay_transform_sequential.cc
b/tests/cpp/relay_transform_sequential.cc
index 756468c..d974f02 100644
--- a/tests/cpp/relay_transform_sequential.cc
+++ b/tests/cpp/relay_transform_sequential.cc
@@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
+#include <tvm/node/structural_equal.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
@@ -102,7 +103,7 @@ TEST(Relay, Sequential) {
auto mod1 = IRModule::FromExpr(expected_func);
mod1 = relay::transform::InferType()(mod1);
auto expected = mod1->Lookup("main");
- CHECK(relay::AlphaEqual(f, expected));
+ CHECK(tvm::StructuralEqual()(f, expected));
}
int main(int argc, char** argv) {
diff --git a/tests/python/frontend/caffe2/test_graph.py
b/tests/python/frontend/caffe2/test_graph.py
index 35914ec..d64b133 100644
--- a/tests/python/frontend/caffe2/test_graph.py
+++ b/tests/python/frontend/caffe2/test_graph.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Test graph equality of caffe2 models."""
+import tvm
from tvm import relay
from tvm.relay import transform
from model_zoo import c2_squeezenet, relay_squeezenet
@@ -23,7 +24,7 @@ from model_zoo import c2_squeezenet, relay_squeezenet
def compare_graph(lhs_mod, rhs_mod):
lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod)
- assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+ assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
def test_squeeze_net():
diff --git a/tests/python/frontend/mxnet/test_graph.py
b/tests/python/frontend/mxnet/test_graph.py
index 0008799..b7c01a5 100644
--- a/tests/python/frontend/mxnet/test_graph.py
+++ b/tests/python/frontend/mxnet/test_graph.py
@@ -25,7 +25,7 @@ import model_zoo
def compare_graph(lhs_mod, rhs_mod):
lhs_mod = transform.InferType()(lhs_mod)
rhs_mod = transform.InferType()(rhs_mod)
- assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+ assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
def test_mlp():
shape = {"data": (1, 1, 28, 28)}
diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py
b/tests/python/relay/test_analysis_extract_fused_functions.py
index 1a70ef1..dab481c 100644
--- a/tests/python/relay/test_analysis_extract_fused_functions.py
+++ b/tests/python/relay/test_analysis_extract_fused_functions.py
@@ -77,7 +77,7 @@ def test_extract_identity():
mod["main"] = mod["main"].with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
- relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
+ tvm.ir.assert_structural_equal(list(items.values())[0], mod["main"], map_free_vars=True)
def test_extract_conv_net():
diff --git a/tests/python/relay/test_annotate_target.py
b/tests/python/relay/test_annotate_target.py
index f4e602a..12a15dc 100644
--- a/tests/python/relay/test_annotate_target.py
+++ b/tests/python/relay/test_annotate_target.py
@@ -136,7 +136,7 @@ def test_extern_dnnl():
mod = annotated(dtype, ishape, w1shape)
mod = transform.AnnotateTarget("dnnl")(mod)
ref_mod = expected(dtype, ishape, w1shape)
- assert relay.analysis.alpha_equal(mod, ref_mod)
+ assert tvm.ir.structural_equal(mod, ref_mod)
def test_run():
if not tvm.get_global_func("relay.ext.dnnl", True):
diff --git a/tests/python/relay/test_call_graph.py
b/tests/python/relay/test_call_graph.py
index 849f015..0af55d2 100644
--- a/tests/python/relay/test_call_graph.py
+++ b/tests/python/relay/test_call_graph.py
@@ -27,7 +27,7 @@ def test_callgraph_construct():
mod["g1"] = relay.Function([x, y], x + y)
call_graph = relay.analysis.CallGraph(mod)
assert "g1" in str(call_graph)
- assert relay.alpha_equal(mod, call_graph.module)
+ assert tvm.ir.structural_equal(mod, call_graph.module)
def test_print_element():
diff --git a/tests/python/relay/test_ir_bind.py
b/tests/python/relay/test_ir_bind.py
index 45474b6..8ba4644 100644
--- a/tests/python/relay/test_ir_bind.py
+++ b/tests/python/relay/test_ir_bind.py
@@ -29,11 +29,11 @@ def test_bind_params():
fexpected =relay.Function(
[y],
relay.add(relay.const(1, "float32"), y))
- assert relay.analysis.alpha_equal(fbinded, fexpected)
+ assert tvm.ir.structural_equal(fbinded, fexpected)
zbinded = relay.bind(z, {y: x})
zexpected = relay.add(x, x)
- assert relay.analysis.alpha_equal(zbinded, zexpected)
+ assert tvm.ir.structural_equal(zbinded, zexpected)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_ir_nodes.py
b/tests/python/relay/test_ir_nodes.py
index 968a3bb..6d4a685 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -21,13 +21,12 @@ from tvm import te
from tvm import relay
from tvm.tir.expr import *
from tvm.relay import op
-from tvm.relay.analysis import graph_equal
import numpy as np
def check_json_roundtrip(node):
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
- assert graph_equal(back, node)
+ assert tvm.ir.structural_equal(back, node, map_free_vars=True)
# Span
diff --git a/tests/python/relay/test_ir_structural_equal_hash.py
b/tests/python/relay/test_ir_structural_equal_hash.py
index cf626d7..5295e17 100644
--- a/tests/python/relay/test_ir_structural_equal_hash.py
+++ b/tests/python/relay/test_ir_structural_equal_hash.py
@@ -107,7 +107,7 @@ def test_func_type_sequal():
ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
tvm.runtime.convert([tp1, tp3]),
tvm.runtime.convert([tr1]))
- translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
+ translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
tvm.runtime.convert([tp2, tp4]),
tvm.runtime.convert([tr2]))
assert ft == translate_vars
diff --git a/tests/python/relay/test_ir_text_printer.py
b/tests/python/relay/test_ir_text_printer.py
index 49518a8..61dbca3 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -20,7 +20,7 @@ from tvm import relay
import tvm.relay.testing
import numpy as np
from tvm.relay import Expr
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal,
assert_graph_equal, free_vars
+from tvm.relay.analysis import free_vars
do_print = [False]
@@ -32,9 +32,9 @@ def astext(p, unify_free_vars=False):
return txt
x = relay.fromtext(txt)
if unify_free_vars:
- assert_graph_equal(x, p)
+ tvm.ir.assert_structural_equal(x, p, map_free_vars=True)
else:
- assert_alpha_equal(x, p)
+ tvm.ir.assert_structural_equal(x, p)
return txt
def show(text):
diff --git a/tests/python/relay/test_op_level10.py
b/tests/python/relay/test_op_level10.py
index 953760c..30e2506 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -99,7 +99,7 @@ def test_checkpoint_alpha_equal():
"""
)
- relay.analysis.assert_alpha_equal(df, df_parsed)
+ tvm.ir.assert_structural_equal(df, df_parsed)
def test_checkpoint_alpha_equal_tuple():
xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i
in range(4)]
@@ -146,7 +146,7 @@ def test_checkpoint_alpha_equal_tuple():
"""
)
- relay.analysis.assert_alpha_equal(df, df_parsed)
+ tvm.ir.assert_structural_equal(df, df_parsed)
def test_collapse_sum_like():
shape = (3, 4, 5, 6)
diff --git a/tests/python/relay/test_pass_alter_op_layout.py
b/tests/python/relay/test_pass_alter_op_layout.py
index eabe758..a30492f 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -66,7 +66,7 @@ def test_alter_op():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_return_none():
@@ -88,7 +88,7 @@ def test_alter_return_none():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(before(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
@@ -151,7 +151,7 @@ def test_alter_layout():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_dual_path():
@@ -214,7 +214,7 @@ def test_alter_layout_dual_path():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_resnet():
"""Test alternating the layout of a residual block
@@ -271,7 +271,7 @@ def test_alter_layout_resnet():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_op():
@@ -318,7 +318,7 @@ def test_alter_layout_broadcast_op():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_broadcast_scalar_op():
@@ -381,7 +381,7 @@ def test_alter_layout_broadcast_scalar_op():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_scalar():
@@ -424,7 +424,7 @@ def test_alter_layout_scalar():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_concatenate():
@@ -478,7 +478,7 @@ def test_alter_layout_concatenate():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# NHWC layout transformation.
def before_nhwc():
@@ -524,7 +524,7 @@ def test_alter_layout_concatenate():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_nchw_upsamping_op():
@@ -561,7 +561,7 @@ def test_alter_layout_nchw_upsamping_op():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_strided_slice():
@@ -597,7 +597,7 @@ def test_alter_layout_strided_slice():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_depthwise_conv2d():
"""Test depthwise_conv2d operator"""
@@ -632,7 +632,7 @@ def test_alter_layout_depthwise_conv2d():
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert(analysis.alpha_equal(a, b))
+ assert(tvm.ir.structural_equal(a, b))
def test_alter_layout_prelu():
"""Test PRelu operator"""
@@ -672,7 +672,7 @@ def test_alter_layout_prelu():
a = run_opt_pass(a, [transform.CanonicalizeOps(),
transform.AlterOpLayout()])
b = run_opt_pass(expected(), transform.InferType())
- assert(analysis.alpha_equal(a, b))
+ assert(tvm.ir.structural_equal(a, b))
def test_alter_layout_pad():
@@ -715,7 +715,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
@@ -749,7 +749,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check that conversion does not happen when padding along split axis.
def before():
@@ -782,7 +782,7 @@ def test_alter_layout_pad():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_pool():
@@ -825,7 +825,7 @@ def test_alter_layout_pool():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
@@ -859,7 +859,7 @@ def test_alter_layout_pool():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_layout_sum():
@@ -902,7 +902,7 @@ def test_alter_layout_sum():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nchw(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check NHWC conversion.
def before_nhwc():
@@ -937,7 +937,7 @@ def test_alter_layout_sum():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be
the
@@ -999,7 +999,7 @@ def test_alter_layout_nhwc_nchw_arm():
a = run_opt_pass(a, transform.AlterOpLayout())
b = run_opt_pass(expected_nhwc(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_alter_op_with_global_var():
"""Test directly replacing an operator with a new one"""
@@ -1041,7 +1041,7 @@ def test_alter_op_with_global_var():
a = transform.AlterOpLayout()(a)
b = transform.InferType()(expected())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" +
str(a)
if __name__ == "__main__":
test_alter_op()
diff --git a/tests/python/relay/test_pass_annotation.py
b/tests/python/relay/test_pass_annotation.py
index 3e7d916..ea92546 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -64,7 +64,7 @@ def test_redundant_annotation():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.alpha_equal(annotated_func, expected_func)
+ assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_annotate_expr():
@@ -91,7 +91,7 @@ def test_annotate_expr():
annotated_expr = annotated()
expected_expr = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(annotated_expr, expected_expr)
+ assert tvm.ir.structural_equal(annotated_expr, expected_expr)
def test_annotate_all():
@@ -120,7 +120,7 @@ def test_annotate_all():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(annotated_func, expected_func)
+ assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_annotate_none():
@@ -146,13 +146,13 @@ def test_annotate_none():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(annotated_func, expected_func)
+ assert tvm.ir.structural_equal(annotated_func, expected_func)
def check_annotated_graph(annotated_func, expected_func):
annotated_func = run_opt_pass(annotated_func, transform.InferType())
expected_func = run_opt_pass(expected_func, transform.InferType())
- assert relay.analysis.alpha_equal(annotated_func, expected_func)
+ assert tvm.ir.structural_equal(annotated_func, expected_func)
def test_conv_network():
@@ -596,7 +596,7 @@ def test_tuple_get_item():
annotated_func = annotated()
expected_func = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(annotated_func, expected_func)
+ assert tvm.ir.structural_equal(annotated_func, expected_func)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_canonicalize_cast.py
b/tests/python/relay/test_pass_canonicalize_cast.py
index e9ab67f..7b6617a 100644
--- a/tests/python/relay/test_pass_canonicalize_cast.py
+++ b/tests/python/relay/test_pass_canonicalize_cast.py
@@ -64,7 +64,7 @@ def test_canonicalize_cast():
mod[gv] = y_expected
mod = _transform.InferType()(mod)
y_expected = mod["expected"]
- assert relay.analysis.alpha_equal(y, y_expected)
+ assert tvm.ir.structural_equal(y, y_expected)
check((1, 16, 7, 7))
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py
b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index ec9bcd9..345f068 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -72,7 +72,7 @@ def test_combine_parallel_conv2d():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, w3, w4, channels1, channels2,
channels3, channels4)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 4, 4, 4)
check((1, 4, 16, 16), 4, 8, 4, 7)
@@ -118,7 +118,7 @@ def test_combine_parallel_conv2d_scale_relu():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1,
channels2)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 8)
@@ -157,7 +157,7 @@ def test_combine_parallel_conv2d_scale():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4, 8)
@@ -193,7 +193,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
transform.CombineParallelConv2D(min_num_branches=2))
y_expected = expected(x, w, out_c, repeat)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
check((1, 4, 16, 16), 4)
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py
b/tests/python/relay/test_pass_combine_parallel_dense.py
index 84d8211..f0f2e18 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -75,7 +75,7 @@ def test_combine_parallel_dense():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, w3, w4)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4)
check(100, 200, 300)
@@ -127,7 +127,7 @@ def test_combine_parallel_dense_biasadd():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4, False)
check(100, 200, 300, False)
@@ -184,7 +184,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
transform.CombineParallelDense(min_num_branches=2))
y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
y_expected = run_opt_pass(y_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y, y_expected)
+ tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
diff --git a/tests/python/relay/test_pass_convert_op_layout.py
b/tests/python/relay/test_pass_convert_op_layout.py
index 9e8f662..c783971 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -52,7 +52,7 @@ def test_no_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_convert_layout():
@@ -87,7 +87,7 @@ def test_conv_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bias_pool_convert_layout():
@@ -132,7 +132,7 @@ def test_conv_bias_pool_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_concat_convert_layout():
@@ -180,7 +180,7 @@ def test_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_dual_path_convert_layout():
@@ -235,7 +235,7 @@ def test_dual_path_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_bn_convert_layout():
@@ -315,7 +315,7 @@ def test_resnet_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_scalar_convert_layout():
@@ -347,7 +347,7 @@ def test_scalar_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_conv_bn_convert_layout():
@@ -395,7 +395,7 @@ def test_conv_bn_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_requantize_convert_layout():
@@ -451,7 +451,7 @@ def test_qnn_conv_requantize_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_concat_convert_layout():
@@ -529,7 +529,7 @@ def test_qnn_conv_concat_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_conv_add_convert_layout():
@@ -609,7 +609,7 @@ def test_qnn_conv_add_convert_layout():
a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py
b/tests/python/relay/test_pass_dead_code_elimination.py
index 3a0bf1f..60dfa62 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -18,7 +18,7 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import Function, transform
-from tvm.relay.analysis import alpha_equal, graph_equal, free_vars,
assert_alpha_equal
+from tvm.relay.analysis import free_vars
from tvm.relay.op import log, add, equal, subtract
from tvm.relay.testing import inception_v3
@@ -69,7 +69,7 @@ def test_used_let():
def test_inline():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
- assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+ tvm.ir.assert_structural_equal(Function(free_vars(orig), orig),
Function([e.d], e.d))
def test_chain_unused_let():
@@ -105,7 +105,7 @@ def test_recursion():
orig = use_f(lambda f: relay.Call(f, [relay.const(2),
relay.const(10000.0)]))
dced = run_opt_pass(orig, transform.DeadCodeElimination())
orig = run_opt_pass(orig, transform.InferType())
- assert_alpha_equal(dced, orig)
+ tvm.ir.assert_structural_equal(dced, orig)
def test_recursion_dead():
x = relay.Let(e.a, e.one, e.three)
diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py
b/tests/python/relay/test_pass_eliminate_common_subexpr.py
index dddbef7..89e3b67 100644
--- a/tests/python/relay/test_pass_eliminate_common_subexpr.py
+++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py
@@ -52,7 +52,7 @@ def test_simple():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
- assert analysis.alpha_equal(z, expected())
+ assert tvm.ir.structural_equal(z, expected())
def test_callback():
@@ -82,7 +82,7 @@ def test_callback():
z = before()
z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
- assert analysis.alpha_equal(z, expected())
+ assert tvm.ir.structural_equal(z, expected())
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_eta_expand.py
b/tests/python/relay/test_pass_eta_expand.py
index ad04e41..84ff54a 100644
--- a/tests/python/relay/test_pass_eta_expand.py
+++ b/tests/python/relay/test_pass_eta_expand.py
@@ -47,7 +47,8 @@ def test_eta_expand_global_var():
}
}
""")
- relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+ tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+ map_free_vars=True)
def test_eta_expand_constructor():
@@ -76,7 +77,8 @@ def test_eta_expand_constructor():
}
}
""")
- relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+ tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+ map_free_vars=True)
if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_fold_constant.py
b/tests/python/relay/test_pass_fold_constant.py
index cc362a2..3ddafd7 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -59,7 +59,7 @@ def test_fold_const():
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.alpha_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_let():
@@ -84,7 +84,7 @@ def test_fold_let():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_tuple():
@@ -106,7 +106,7 @@ def test_fold_tuple():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_concat():
@@ -125,7 +125,7 @@ def test_fold_concat():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_shape_of():
@@ -146,7 +146,7 @@ def test_fold_shape_of():
for dtype in ["int32", "float32"]:
zz = run_opt_pass(before(dtype), transform.FoldConstant())
zexpected = run_opt_pass(expected(dtype), transform.InferType())
- assert relay.analysis.graph_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_full():
@@ -161,7 +161,7 @@ def test_fold_full():
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.graph_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_fold_batch_norm():
@@ -202,7 +202,7 @@ def test_fold_batch_norm():
mod = remove_bn_pass(mod)
expect = run_infer_type(expected())
- assert relay.analysis.graph_equal(mod["main"], expect)
+ assert tvm.ir.structural_equal(mod["main"], expect)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py
b/tests/python/relay/test_pass_fold_scale_axis.py
index 4c094fb..bf2a708 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -79,7 +79,7 @@ def test_fold_fwd_simple():
y1_folded = run_opt_pass(y1_folded, transform.InferType())
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 2)
@@ -148,7 +148,7 @@ def test_fold_fwd_dual_path():
weight = relay.var("weight", type_dict["weight"])
y1_expected = expected(x, weight, in_bias, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 3), 3)
@@ -177,7 +177,7 @@ def test_fold_fwd_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
- assert relay.analysis.alpha_equal(y1, y1_folded)
+ assert tvm.ir.structural_equal(y1, y1_folded)
check((2, 11, 10, 4), 4)
@@ -205,7 +205,7 @@ def test_fold_fwd_relu_fail():
y1 = before(x, weight, in_bias, in_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
- assert relay.analysis.alpha_equal(y1, y1_folded)
+ assert tvm.ir.structural_equal(y1, y1_folded)
in_scale = relay.var("in_scale", shape=(4,))
check((2, 11, 10, 4), 4, in_scale)
@@ -249,7 +249,7 @@ def test_fold_fwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
y1_expected = expected(x, weight, in_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
@@ -300,7 +300,7 @@ def test_fold_bwd_simple():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
@@ -359,7 +359,7 @@ def test_fold_bwd_dual_path():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
@@ -431,7 +431,7 @@ def test_fold_bwd_dual_consumer():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_bias, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 4)
@@ -480,7 +480,7 @@ def test_fold_bwd_fail():
y1 = fbefore(x, weight, out_bias, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
- assert relay.analysis.alpha_equal(y1_folded, y1)
+ assert tvm.ir.structural_equal(y1_folded, y1)
check((4, 4, 10, 10), 4, fail1)
check((4, 4, 10, 10), 4, fail2)
@@ -505,7 +505,7 @@ def test_fold_bwd_relu_fail():
y1 = before(x, weight, out_scale, channels)
y1 = run_opt_pass(y1, transform.InferType())
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
- assert relay.analysis.alpha_equal(y1, y1_folded)
+ assert tvm.ir.structural_equal(y1, y1_folded)
out_scale = relay.var("in_scale", shape=(4, 1, 1))
check((4, 4, 10, 10), 4, out_scale)
@@ -547,7 +547,7 @@ def test_fold_bwd_negative_scale():
y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
y1_expected = expected(x, weight, out_scale, channels)
y1_expected = run_opt_pass(y1_expected, transform.InferType())
- assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+ assert tvm.ir.structural_equal(y1_folded, y1_expected)
check((2, 4, 10, 10), 8)
diff --git a/tests/python/relay/test_pass_fuse_ops.py
b/tests/python/relay/test_pass_fuse_ops.py
index 108c91b..6b7d297 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -45,7 +45,7 @@ def test_fuse_simple():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_conv2d_fuse():
@@ -127,7 +127,7 @@ def test_conv2d_fuse():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_concatenate():
@@ -167,7 +167,7 @@ def test_concatenate():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_tuple_root():
@@ -204,7 +204,7 @@ def test_tuple_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_stop_fusion():
@@ -235,7 +235,7 @@ def test_stop_fusion():
z = before(dshape)
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_fuse_myia_regression():
@@ -271,7 +271,7 @@ def test_fuse_myia_regression():
f = before(dshape, dtype)
zz = run_opt_pass(f, transform.FuseOps())
after = run_opt_pass(expected(dshape, dtype), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_fuse_tuple_get_elemwise():
@@ -309,7 +309,7 @@ def test_fuse_tuple_get_elemwise():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_tuple_get_root():
@@ -346,7 +346,7 @@ def test_tuple_get_root():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(dim), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
@@ -379,7 +379,7 @@ def test_tuple_intermediate():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(x), transform.InferType())
- assert relay.analysis.alpha_equal(m["main"], after)
+ assert tvm.ir.structural_equal(m["main"], after)
def test_tuple_consecutive():
@@ -437,7 +437,7 @@ def test_tuple_consecutive():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(m["main"], after)
+ assert tvm.ir.structural_equal(m["main"], after)
def test_inception_like():
@@ -510,7 +510,7 @@ def test_inception_like():
m = fuse2(tvm.IRModule.from_expr(orig))
relay.build(m, 'llvm')
after = run_opt_pass(expected(dshape), transform.InferType())
- assert relay.analysis.alpha_equal(m["main"], after)
+ assert tvm.ir.structural_equal(m["main"], after)
def test_fuse_parallel_injective():
@@ -541,7 +541,7 @@ def test_fuse_parallel_injective():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
assert not relay.analysis.free_vars(zz)
after = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
def test_immutable():
@@ -570,8 +570,8 @@ def test_immutable():
mod = before()
new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
- assert relay.analysis.alpha_equal(mod, before())
- assert relay.analysis.alpha_equal(new_mod, expected())
+ assert tvm.ir.structural_equal(mod, before())
+ assert tvm.ir.structural_equal(new_mod, expected())
def test_split():
@@ -619,7 +619,7 @@ def test_fuse_max():
zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
zz = run_opt_pass(z, transform.FuseOps())
after = run_opt_pass(expected(), transform.InferType())
- assert relay.analysis.alpha_equal(zz, after)
+ assert tvm.ir.structural_equal(zz, after)
if __name__ == "__main__":
test_fuse_simple()
diff --git a/tests/python/relay/test_pass_gradient.py
b/tests/python/relay/test_pass_gradient.py
index 6f2a125..efd01cb 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -19,7 +19,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
+from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
from tvm.relay.prelude import Prelude
@@ -292,7 +292,7 @@ def test_concat():
func = relay.Function([x], y)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
- assert_alpha_equal(back_func.checked_type, relay.FuncType([t],
relay.TupleType([rt, relay.TupleType([t])])))
+ tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t],
relay.TupleType([rt, relay.TupleType([t])])))
# no value validation as concatenate has dummy gradient right now.
diff --git a/tests/python/relay/test_pass_inline.py
b/tests/python/relay/test_pass_inline.py
index f4943ab..0f6d539 100644
--- a/tests/python/relay/test_pass_inline.py
+++ b/tests/python/relay/test_pass_inline.py
@@ -115,7 +115,7 @@ def test_call_chain_inline_leaf():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels():
@@ -188,7 +188,7 @@ def test_call_chain_inline_multiple_levels():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_call_chain_inline_multiple_levels_extern_compiler():
@@ -266,7 +266,7 @@ def
test_call_chain_inline_multiple_levels_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_call_with_global():
@@ -321,7 +321,7 @@ def test_recursive_call_with_global():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_recursive_called():
@@ -330,7 +330,7 @@ def test_recursive_called():
mod["main"] = relay.Function([iarg], sum_up(iarg))
ref_mod = mod
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, ref_mod)
+ assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called():
@@ -356,7 +356,7 @@ def test_recursive_not_called():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
- assert relay.analysis.alpha_equal(mod, ref_mod)
+ assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_recursive_not_called_extern_compiler():
@@ -387,7 +387,7 @@ def test_recursive_not_called_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
ref_mod = expected()
- assert relay.analysis.alpha_equal(mod, ref_mod)
+ assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
def test_globalvar_as_call_arg():
@@ -434,7 +434,7 @@ def test_globalvar_as_call_arg():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_as_call_arg_extern_compiler():
@@ -500,7 +500,7 @@ def test_globalvar_as_call_arg_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args():
@@ -531,7 +531,7 @@ def test_inline_globalvar_without_args():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_inline_globalvar_without_args_extern_compiler():
@@ -566,7 +566,7 @@ def test_inline_globalvar_without_args_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_globalvar_called_by_multiple_functions():
@@ -644,7 +644,7 @@ def test_globalvar_called_by_multiple_functions():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_entry_with_inline():
@@ -674,7 +674,7 @@ def test_entry_with_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, get_mod())
+ assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline():
@@ -707,7 +707,7 @@ def test_callee_not_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, get_mod())
+ assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
def test_callee_not_inline_leaf_inline():
@@ -765,7 +765,7 @@ def test_callee_not_inline_leaf_inline():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
def test_callee_not_inline_leaf_inline_extern_compiler():
@@ -830,7 +830,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
mod = get_mod()
mod = relay.transform.Inline()(mod)
- assert relay.analysis.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_legalize.py
b/tests/python/relay/test_pass_legalize.py
index 9976eca..1456700 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -68,7 +68,7 @@ def test_legalize():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_none():
"""Test doing nothing by returning 'None' """
@@ -89,7 +89,7 @@ def test_legalize_none():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
assert(called[0])
def test_legalize_multiple_ops():
@@ -134,7 +134,7 @@ def test_legalize_multiple_ops():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_legalize_multi_input():
@@ -170,7 +170,7 @@ def test_legalize_multi_input():
a = run_opt_pass(a, transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_manager.py
b/tests/python/relay/test_pass_manager.py
index aed0269..0a6555b 100644
--- a/tests/python/relay/test_pass_manager.py
+++ b/tests/python/relay/test_pass_manager.py
@@ -111,7 +111,7 @@ def get_rand(shape, dtype='float32'):
def check_func(func, ref_func):
func = run_infer_type(func)
ref_func = run_infer_type(ref_func)
- assert analysis.graph_equal(func, ref_func)
+ assert tvm.ir.structural_equal(func, ref_func)
def test_module_pass():
@@ -211,7 +211,7 @@ def test_function_class_pass():
mod = fpass(mod)
# wrap in expr
mod2 = tvm.IRModule.from_expr(f1)
- assert relay.alpha_equal(mod["main"], mod2["main"])
+ assert tvm.ir.structural_equal(mod["main"], mod2["main"])
def test_function_pass():
@@ -496,7 +496,7 @@ def test_sequential_with_scoping():
zz = mod["main"]
zexpected = run_infer_type(expected())
- assert analysis.alpha_equal(zz, zexpected)
+ assert tvm.ir.structural_equal(zz, zexpected)
def test_print_ir(capfd):
diff --git a/tests/python/relay/test_pass_merge_composite.py
b/tests/python/relay/test_pass_merge_composite.py
index 72ed3fc..3c70cf2 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for merge composite."""
+import tvm
from tvm import relay
from tvm import tir
from tvm.relay.testing import run_opt_pass
@@ -192,7 +193,7 @@ def test_simple_merge():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_branch_merge():
@@ -270,7 +271,7 @@ def test_branch_merge():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_reuse_call_merge():
@@ -329,7 +330,7 @@ def test_reuse_call_merge():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_patterns():
@@ -422,7 +423,7 @@ def test_multiple_patterns():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_merge_order():
@@ -494,7 +495,7 @@ def test_merge_order():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check B highest priority
pattern_table = [
@@ -505,7 +506,7 @@ def test_merge_order():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check C highest priority
pattern_table = [
@@ -516,7 +517,7 @@ def test_merge_order():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_parallel_merge():
@@ -563,7 +564,7 @@ def test_parallel_merge():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_multiple_input_subgraphs():
@@ -676,13 +677,13 @@ def test_multiple_input_subgraphs():
result = run_opt_pass(before()['A'],
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_A(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
# check case 'B'
result = run_opt_pass(before()['B'],
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(after_B(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
def test_tuple_get_item_merge():
@@ -728,7 +729,7 @@ def test_tuple_get_item_merge():
result = run_opt_pass(before(),
relay.transform.MergeComposite(pattern_table))
assert not relay.analysis.free_vars(result)
expected = run_opt_pass(expected(), relay.transform.InferType())
- assert relay.analysis.alpha_equal(result, expected)
+ assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_partial_eval.py
b/tests/python/relay/test_pass_partial_eval.py
index 1299084..0f3eea6 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -19,7 +19,6 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const,
RefRead, RefWrite, RefCreate
@@ -124,7 +123,7 @@ def test_ad():
body = relay.Let(x1, o, body)
expected = Function([d], relay.Let(x, m, body))
expected = run_opt_pass(expected, transform.InferType())
- assert_alpha_equal(g, expected)
+ tvm.ir.assert_structural_equal(g, expected)
def test_if_ref():
@@ -312,7 +311,7 @@ def test_concat():
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
- assert_alpha_equal(dcpe(orig), orig)
+ tvm.ir.assert_structural_equal(dcpe(orig), orig)
def test_triangle_number():
@@ -321,7 +320,7 @@ def test_triangle_number():
f_var = Var("f")
f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x -
const(1))))
orig = run_infer_type(Let(f_var, f, f_var(const(10))))
- assert_alpha_equal(dcpe(orig), const(55))
+ tvm.ir.assert_structural_equal(dcpe(orig), const(55))
def test_nat_update():
@@ -337,7 +336,7 @@ def test_tuple_match():
b = relay.Var("b")
clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a),
relay.PatternVar(b)]), a + b)
x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
- assert_alpha_equal(dcpe(x), const(2))
+ tvm.ir.assert_structural_equal(dcpe(x), const(2))
if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_partition_graph.py
b/tests/python/relay/test_pass_partition_graph.py
index 1f37ab8..fc8dfb6 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -339,7 +339,7 @@ def test_extern_ccompiler_default_ops():
fused_mod = transform.FuseOps(2)(mod)
expected_mod = expected()
- assert relay.alpha_equal(fused_mod, expected_mod)
+ assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
x_data = np.random.rand(8, 8).astype('float32')
y_data = np.random.rand(8, 8).astype('float32')
@@ -427,7 +427,7 @@ def test_extern_dnnl():
mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)
- assert relay.alpha_equal(mod, expected())
+ assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
ref_mod = tvm.IRModule()
ref_mod["main"] = get_func()
@@ -561,7 +561,7 @@ def test_function_lifting():
partitioned = partition()
ref_mod = expected()
- assert relay.analysis.alpha_equal(partitioned, ref_mod)
+ assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_function_lifting_inline():
@@ -631,7 +631,7 @@ def test_function_lifting_inline():
partitioned = partition()
ref_mod = expected()
- assert relay.analysis.alpha_equal(partitioned, ref_mod)
+ assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_constant_propagation():
@@ -671,7 +671,7 @@ def test_constant_propagation():
mod = transform.PartitionGraph()(mod)
expected_mod = expected()
- assert relay.alpha_equal(mod, expected_mod)
+ assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
y_data = np.random.rand(8, 8).astype('float32')
np_add = ones + y_data
diff --git a/tests/python/relay/test_pass_qnn_legalize.py
b/tests/python/relay/test_pass_qnn_legalize.py
index b164821..e7980e7 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -31,7 +31,7 @@ def alpha_equal(x, y):
"""
x = x['main']
y = y['main']
- return analysis.alpha_equal(x, y) and analysis.structural_hash(x) ==
analysis.structural_hash(y)
+ return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) ==
analysis.structural_hash(y)
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
@@ -85,12 +85,12 @@ def test_qnn_legalize():
# Check that Relay Legalize does not change the graph.
a = run_opt_pass(a, relay.transform.Legalize())
b = run_opt_pass(before(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
# Check that QNN Legalize modifies the graph.
a = run_opt_pass(a, relay.qnn.transform.Legalize())
b = run_opt_pass(expected(), transform.InferType())
- assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
def test_qnn_legalize_qnn_conv2d():
diff --git a/tests/python/relay/test_pass_remove_unused_functions.py
b/tests/python/relay/test_pass_remove_unused_functions.py
index 3381634..43b54e9 100644
--- a/tests/python/relay/test_pass_remove_unused_functions.py
+++ b/tests/python/relay/test_pass_remove_unused_functions.py
@@ -110,7 +110,7 @@ def test_call_globalvar_without_args():
mod = get_mod()
ref_mod = get_mod()
mod = relay.transform.RemoveUnusedFunctions()(mod)
- assert relay.alpha_equal(mod, ref_mod)
+ assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_simplify_inference.py
b/tests/python/relay/test_pass_simplify_inference.py
index bb39893..3a8c90b 100644
--- a/tests/python/relay/test_pass_simplify_inference.py
+++ b/tests/python/relay/test_pass_simplify_inference.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from tvm.ir import IRModule
+from tvm.ir import IRModule, structural_equal
from tvm import relay as rly
from tvm.relay.transform import SimplifyInference
@@ -56,7 +56,7 @@ def test_simplify_batchnorm(dtype='float32'):
mod = simplify(mod)
y1 = mod["main"].body
- assert rly.analysis.graph_equal(y1, y2)
+ assert structural_equal(y1, y2, map_free_vars=True)
check(2, 1, 1)
check(4, 1, 1)
diff --git a/tests/python/relay/test_pass_to_a_normal_form.py
b/tests/python/relay/test_pass_to_a_normal_form.py
index 29818f8..d7babf3 100644
--- a/tests/python/relay/test_pass_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -18,7 +18,7 @@ import numpy as np
import tvm
from tvm import te
from tvm import relay
-from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.analysis import detect_feature
from tvm.relay import op, create_executor, transform
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, count
diff --git a/tests/python/relay/test_type_functor.py
b/tests/python/relay/test_type_functor.py
index 9e023bc..b90a688 100644
--- a/tests/python/relay/test_type_functor.py
+++ b/tests/python/relay/test_type_functor.py
@@ -18,7 +18,6 @@ import tvm
from tvm import te
from tvm import relay
from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
-from tvm.relay.analysis import assert_graph_equal
from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
from tvm.relay.adt import TypeData
@@ -34,7 +33,8 @@ def check_visit(typ):
ev = TypeVisitor()
ev.visit(typ)
- assert_graph_equal(TypeMutator().visit(typ), typ)
+ tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ,
+ map_free_vars=True)
def test_type_var():
diff --git a/tests/python/unittest/test_ir_type.py
b/tests/python/unittest/test_ir_type.py
index f919f92..a0e7d2b 100644
--- a/tests/python/unittest/test_ir_type.py
+++ b/tests/python/unittest/test_ir_type.py
@@ -18,10 +18,9 @@
import tvm
def check_json_roundtrip(node):
- from tvm.relay.analysis import graph_equal
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
- assert graph_equal(back, node)
+ assert tvm.ir.structural_equal(back, node, map_free_vars=True)
def test_prim_type():