This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 919ae88  [REFACTOR][IR] alpha_equal to structural_equal (#5161)
919ae88 is described below

commit 919ae889638555b82de2d124d5f3e08d76bf789b
Author: Zhi <[email protected]>
AuthorDate: Sun Mar 29 09:58:58 2020 -0700

    [REFACTOR][IR] alpha_equal to structural_equal (#5161)
---
 include/tvm/ir/type.h                              |   4 +-
 include/tvm/relay/analysis.h                       |  55 --
 python/tvm/ir/type.py                              |   3 +-
 python/tvm/relay/__init__.py                       |   1 -
 python/tvm/relay/analysis/analysis.py              |  72 ---
 src/ir/module.cc                                   |  19 +-
 src/relay/analysis/alpha_equal.cc                  | 628 ---------------------
 src/relay/analysis/type_solver.cc                  |   7 +-
 src/relay/backend/compile_engine.h                 |   3 +-
 src/relay/backend/vm/lambda_lift.cc                |   4 +-
 src/relay/op/tensor/transform.cc                   |   7 +-
 src/relay/transforms/lazy_gradient_init.cc         |  22 +-
 src/relay/transforms/pattern_util.h                |   3 +-
 src/relay/transforms/to_cps.cc                     |   2 +-
 tests/cpp/relay_pass_alpha_equal.cc                |  67 ---
 tests/cpp/relay_pass_type_infer_test.cc            |   3 +-
 tests/cpp/relay_transform_sequential.cc            |   3 +-
 tests/python/frontend/caffe2/test_graph.py         |   3 +-
 tests/python/frontend/mxnet/test_graph.py          |   2 +-
 .../relay/test_analysis_extract_fused_functions.py |   2 +-
 tests/python/relay/test_annotate_target.py         |   2 +-
 tests/python/relay/test_call_graph.py              |   2 +-
 tests/python/relay/test_ir_bind.py                 |   4 +-
 tests/python/relay/test_ir_nodes.py                |   3 +-
 .../python/relay/test_ir_structural_equal_hash.py  |   2 +-
 tests/python/relay/test_ir_text_printer.py         |   6 +-
 tests/python/relay/test_op_level10.py              |   4 +-
 tests/python/relay/test_pass_alter_op_layout.py    |  46 +-
 tests/python/relay/test_pass_annotation.py         |  12 +-
 tests/python/relay/test_pass_canonicalize_cast.py  |   2 +-
 .../relay/test_pass_combine_parallel_conv2d.py     |   8 +-
 .../relay/test_pass_combine_parallel_dense.py      |   6 +-
 tests/python/relay/test_pass_convert_op_layout.py  |  22 +-
 .../relay/test_pass_dead_code_elimination.py       |   6 +-
 .../relay/test_pass_eliminate_common_subexpr.py    |   4 +-
 tests/python/relay/test_pass_eta_expand.py         |   6 +-
 tests/python/relay/test_pass_fold_constant.py      |  14 +-
 tests/python/relay/test_pass_fold_scale_axis.py    |  22 +-
 tests/python/relay/test_pass_fuse_ops.py           |  30 +-
 tests/python/relay/test_pass_gradient.py           |   4 +-
 tests/python/relay/test_pass_inline.py             |  32 +-
 tests/python/relay/test_pass_legalize.py           |   8 +-
 tests/python/relay/test_pass_manager.py            |   6 +-
 tests/python/relay/test_pass_merge_composite.py    |  23 +-
 tests/python/relay/test_pass_partial_eval.py       |   9 +-
 tests/python/relay/test_pass_partition_graph.py    |  10 +-
 tests/python/relay/test_pass_qnn_legalize.py       |   6 +-
 .../relay/test_pass_remove_unused_functions.py     |   2 +-
 tests/python/relay/test_pass_simplify_inference.py |   4 +-
 tests/python/relay/test_pass_to_a_normal_form.py   |   2 +-
 tests/python/relay/test_type_functor.py            |   4 +-
 tests/python/unittest/test_ir_type.py              |   3 +-
 52 files changed, 208 insertions(+), 1016 deletions(-)

diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index 6f6c66a..0e65758 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
   }
 
   bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) 
const {
-    return equal(kind, other->kind);
+    return
+        equal(kind, other->kind) &&
+        equal.FreeVarEqualImpl(this, other);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h
index fe8fae5..51eae5a 100644
--- a/include/tvm/relay/analysis.h
+++ b/include/tvm/relay/analysis.h
@@ -65,61 +65,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
 TVM_DLL bool ConstantCheck(const Expr& e);
 
 /*!
- * \brief Compare two expressions for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `let x = 1 in x` is equal to `let y = 1 in y`.
- *
- *   See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- *   for more details.
- *
- *   \param e1 The left hand expression.
- *   \param e2 The right hand expression.
- *
- *   \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);
-
-/*!
- * \brief Compare two types for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * expressions without regard to variable choice.
- *
- * For example: `forall s, Tensor[f32, s]` is equal to
- * `forall w, Tensor[f32, w]`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand type.
- * \param t2 The right hand type.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);
-
-/*!
- * \brief Compare two patterns for structural equivalence.
- *
- * This comparison operator respects scoping and compares
- * patterns without regard to variable choice.
- *
- * For example: `A(x, _, y)` is equal to `A(z, _, a)`.
- *
- * See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
- * for more details.
- *
- * \param t1 The left hand pattern.
- * \param t2 The right hand pattern.
- *
- * \return true if equal, otherwise false
- */
-TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);
-
-/*!
  * \brief Check that each Var is only bound once.
  *
  * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py
index a61c6e4..e980011 100644
--- a/python/tvm/ir/type.py
+++ b/python/tvm/ir/type.py
@@ -16,6 +16,7 @@
 # under the License.
 """Unified type system in the project."""
 from enum import IntEnum
+import tvm
 import tvm._ffi
 
 from .base import Node
@@ -26,7 +27,7 @@ class Type(Node):
     """The base class of all types."""
     def __eq__(self, other):
         """Compare two types for structural equivalence."""
-        return bool(_ffi_api.type_alpha_equal(self, other))
+        return bool(tvm.ir.structural_equal(self, other))
 
     def __ne__(self, other):
         return not self.__eq__(other)
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 95545c8..1517cf9 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -33,7 +33,6 @@ from . import parser
 
 from . import transform
 from . import analysis
-from .analysis import alpha_equal
 from .build_module import build, create_executor, optimize
 from .transform import build_config
 from . import debug
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 722f3b0..b09a40b 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
     return _ffi_api.all_type_vars(expr, use_mod)
 
 
-def alpha_equal(lhs, rhs):
-    """Compare two Relay expr for structural equivalence (alpha equivalence).
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-
-    Returns
-    -------
-    result : bool
-        True iff lhs is alpha equal to rhs.
-    """
-    return bool(_ffi_api._alpha_equal(lhs, rhs))
-
-
-def assert_alpha_equal(lhs, rhs):
-    """Assert that two Relay expr is structurally equivalent. (alpha 
equivalence).
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-        One of the input Expression.
-
-    rhs : tvm.relay.Expr
-        One of the input Expression.
-    """
-    _ffi_api._assert_alpha_equal(lhs, rhs)
-
-
-def graph_equal(lhs, rhs):
-    """Compare two Relay expr for data-flow equivalence.
-    The difference between this and alpha-equality is that
-    variables are not expected to match between lhs and rhs;
-    they are treated as sources and are mapped between each other.
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-      One of the input Expression.
-
-    rhs : tvm.relay.Expr
-      One of the input Expression.
-
-    Returns
-    -------
-    result : bool
-      True iff lhs is data-flow equivalent to rhs.
-    """
-    return bool(_ffi_api._graph_equal(lhs, rhs))
-
-
-def assert_graph_equal(lhs, rhs):
-    """Compare two Relay expr for data-flow equivalence.
-    The difference between this and alpha-equality is that
-    variables are not expected to match between lhs and rhs;
-    they are treated as sources and are mapped between each other.
-
-    Parameters
-    ----------
-    lhs : tvm.relay.Expr
-      One of the input Expression.
-
-    rhs : tvm.relay.Expr
-      One of the input Expression.
-    """
-    _ffi_api._assert_graph_equal(lhs, rhs)
-
-
 def collect_device_info(expr):
     """Collect the device allocation map for the given expression. The device
     ids are propagated from the `device_copy` operators.
diff --git a/src/ir/module.cc b/src/ir/module.cc
index de09314..c7474de 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -23,6 +23,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <tvm/ir/module.h>
+#include <tvm/node/structural_equal.h>
 // NOTE: reverse dependency on relay.
 // These dependencies do not happen at the interface-level,
 // and are only used in minimum cases where they are clearly marked.
@@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
         << AsText(func, false)
         << std::endl;
   }
-  func =
-      relay::Function(concat(func->params, fv),
-                                func->body,
-                                func->ret_type,
-                                concat(func->type_params, ftv),
-                                func->attrs);
+  func = relay::Function(concat(func->params, fv),
+                         func->body,
+                         func->ret_type,
+                         concat(func->type_params, ftv),
+                         func->attrs);
   // Type check the item before we add it to the module.
   relay::Function checked_func = InferType(func, mod, var);
   return checked_func;
@@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
     CHECK(update)
         << "Already have definition for " << var->name_hint;
     auto old_type = functions[var]->checked_type();
-    CHECK(relay::AlphaEqual(type, old_type))
+    CHECK(tvm::StructuralEqual()(type, old_type))
         << "Module#update changes type, not possible in this mode.";
   }
   var->checked_type_ = type;
@@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
   if (auto* func_node = expr.as<BaseFuncNode>()) {
     func = GetRef<BaseFunc>(func_node);
   } else {
-    func = relay::Function(
-        relay::FreeVars(expr), expr, Type(),
-        relay::FreeTypeVars(expr, mod), {});
+    func = relay::Function(relay::FreeVars(expr), expr, Type(),
+                           relay::FreeTypeVars(expr, mod), {});
   }
   auto main_gv = GlobalVar("main");
   mod->Add(main_gv, func);
diff --git a/src/relay/analysis/alpha_equal.cc b/src/relay/analysis/alpha_equal.cc
deleted file mode 100644
index 28c7681..0000000
--- a/src/relay/analysis/alpha_equal.cc
+++ /dev/null
@@ -1,628 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/relay/analysis/alpha_equal.cc
- * \brief Alpha equality check by deep comparing two nodes.
- */
-#include <tvm/ir/type_functor.h>
-#include <tvm/tir/ir_pass.h>
-#include <tvm/relay/expr_functor.h>
-#include <tvm/relay/pattern_functor.h>
-#include <tvm/runtime/ndarray.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/op_attr_types.h>
-#include <tvm/relay/attrs/nn.h>
-#include "../../ir/attr_functor.h"
-
-
-namespace tvm {
-namespace relay {
-
-// Alpha Equal handler for Relay.
-class AlphaEqualHandler:
-      public AttrsEqualHandler,
-      public TypeFunctor<bool(const Type&, const Type&)>,
-      public ExprFunctor<bool(const Expr&, const Expr&)>,
-      public PatternFunctor<bool(const Pattern&, const Pattern&)> {
- public:
-  explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
-    : map_free_var_(map_free_var), assert_mode_(assert_mode) { }
-
-  /*!
-   * Check equality of two nodes.
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return The comparison result.
-   */
-  bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
-    return VisitAttr(lhs, rhs);
-  }
-  /*!
-   * Check equality of two attributes.
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return The comparison result.
-   */
-  bool AttrEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
-    auto compute = [&]() {
-      return VisitAttr(lhs, rhs);
-    };
-    return Compare(compute(), lhs, rhs);
-  }
-  /*!
-   * Check equality of two types.
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return the comparison result.
-   */
-  bool TypeEqual(const Type& lhs, const Type& rhs) {
-    auto compute = [&]() {
-      if (lhs.same_as(rhs)) return true;
-      if (!lhs.defined() || !rhs.defined()) return false;
-      return this->VisitType(lhs, rhs);
-    };
-    return Compare(compute(), lhs, rhs);
-  }
-
-  bool Compare(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
-    if (assert_mode_) {
-      CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << 
AsText(rhs, true);
-    }
-    return result;
-  }
-  /*!
-   * Check equality of two expressions.
-   *
-   * \note We run graph structural equality checking when comparing two Exprs.
-   *   This means that AlphaEqualHandler can only be used once for each pair.
-   *   The equality checker checks data-flow equvalence of the Expr DAG.
-   *   This function also runs faster as it memomizes equal_map.
-   *
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return The comparison result.
-   */
-  bool ExprEqual(const Expr& lhs, const Expr& rhs) {
-    auto compute = [&]() {
-      if (lhs.same_as(rhs)) return true;
-      if (!lhs.defined() || !rhs.defined()) return false;
-      auto it = equal_map_.find(lhs);
-      if (it != equal_map_.end()) {
-        return it->second.same_as(rhs);
-      }
-      if (this->VisitExpr(lhs, rhs)) {
-        equal_map_[lhs] = rhs;
-        return true;
-      } else {
-        return false;
-      }
-    };
-    return Compare(compute(), lhs, rhs);
-  }
-
- protected:
-  // So that the new definition of equality in relay can be handled directly.
-  // Specifically, if a DictAttr contains a value defined by a relay AST.
-  // We want to able to recursively check the equality in the attr defined by 
the relay AST.
-  bool VisitAttr(const ObjectRef& lhs, const ObjectRef& rhs) final {
-    if (lhs.same_as(rhs)) return true;
-    if (!lhs.defined() && rhs.defined()) return false;
-    if (!rhs.defined() && lhs.defined()) return false;
-    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
-      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return 
false;
-      return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
-    }
-    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
-      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return 
false;
-      return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
-    }
-    if (const auto lhsm = lhs.as<IRModuleNode>()) {
-      auto rhsm = rhs.as<IRModuleNode>();
-      if (!rhsm) return false;
-      if (lhsm->functions.size() != rhsm->functions.size()) return false;
-      for (const auto& p : lhsm->functions) {
-        if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
-      }
-      if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) 
return false;
-      for (const auto& p : lhsm->type_definitions) {
-        if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
-            !Equal(p.second, rhsm->LookupTypeDef(p.first->name_hint))) {
-          return false;
-        }
-      }
-      return true;
-    }
-    // Fall back to the object equal case.
-    return AttrsEqualHandler::VisitAttr(lhs, rhs);
-  }
-  /*!
-   * \brief Check if data type equals each other.
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return The compare result.
-   */
-  bool DataTypeEqual(const DataType& lhs, const DataType& rhs) {
-    return lhs == rhs;
-  }
-  /*!
-   * \brief Check Equality of leaf node of the graph.
-   *  if map_free_var_ is set to true, try to map via equal node.
-   * \param lhs The left hand operand.
-   * \param rhs The right hand operand.
-   * \return The compare result.
-   */
-  bool LeafObjectEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
-    if (lhs.same_as(rhs)) return true;
-    auto it = equal_map_.find(lhs);
-    if (it != equal_map_.end()) {
-      return it->second.same_as(rhs);
-    } else {
-      if (map_free_var_) {
-        if (lhs->type_index() != rhs->type_index()) return false;
-        equal_map_[lhs] = rhs;
-        return true;
-      } else {
-        return false;
-      }
-    }
-  }
-  using AttrsEqualHandler::VisitAttr_;
-  bool VisitAttr_(const tvm::tir::VarNode* lhs, const ObjectRef& other) final {
-    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
-  }
-
-  // Type equality
-  bool VisitType_(const TensorTypeNode* lhs, const Type& other) final {
-    if (const TensorTypeNode* rhs = other.as<TensorTypeNode>()) {
-      return (lhs->dtype == rhs->dtype &&
-              AttrEqual(lhs->shape, rhs->shape));
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const IncompleteTypeNode* lhs, const Type& other) final {
-    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
-  }
-
-  bool VisitType_(const PrimTypeNode* lhs, const Type& other) final {
-    if (const PrimTypeNode* rhs = other.as<PrimTypeNode>()) {
-      return lhs->dtype == rhs->dtype;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const PointerTypeNode* lhs, const Type& other) final {
-    if (const PointerTypeNode* rhs = other.as<PointerTypeNode>()) {
-      return TypeEqual(lhs->element_type, rhs->element_type);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
-    if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
-      if (lhs->kind != rhs->kind) return false;
-      return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const FuncTypeNode* lhs, const Type& other) final {
-    if (const FuncTypeNode* rhs = other.as<FuncTypeNode>()) {
-      if (lhs->arg_types.size() != rhs->arg_types.size()) return false;
-      if (lhs->type_params.size() != rhs->type_params.size()) return false;
-      if (lhs->type_constraints.size() != rhs->type_constraints.size()) return 
false;
-      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
-        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) {
-          return false;
-        }
-        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
-      }
-      for (size_t i = 0; i < lhs->arg_types.size(); i++) {
-        if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
-      }
-      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
-      for (size_t i = 0; i < lhs->type_constraints.size(); i++) {
-        if (!TypeEqual(lhs->type_constraints[i],
-                       rhs->type_constraints[i])) {
-          return false;
-        }
-      }
-      return true;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const TypeRelationNode* lhs, const Type& other) final {
-    if (const TypeRelationNode* rhs = other.as<TypeRelationNode>()) {
-      if (lhs->func->name != rhs->func->name) return false;
-      if (lhs->num_inputs != rhs->num_inputs) return false;
-      if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false;
-      if (lhs->args.size() != rhs->args.size()) return false;
-      for (size_t i = 0; i < lhs->args.size(); ++i) {
-        if (!TypeEqual(lhs->args[i], rhs->args[i])) return false;
-      }
-      return true;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const TupleTypeNode* lhs, const Type& other) final {
-    if (const TupleTypeNode* rhs = other.as<TupleTypeNode>()) {
-      if (lhs->fields.size() != rhs->fields.size()) return false;
-      for (size_t i = 0; i < lhs->fields.size(); ++i) {
-        if (!TypeEqual(lhs->fields[i], rhs->fields[i])) return false;
-      }
-      return true;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitType_(const RelayRefTypeNode* lhs, const Type& other) final {
-    if (const RelayRefTypeNode* rhs = other.as<RelayRefTypeNode>()) {
-      return TypeEqual(lhs->value, rhs->value);
-    }
-    return false;
-  }
-
-  bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
-    return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
-  }
-
-  bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
-    const TypeCallNode* rhs = other.as<TypeCallNode>();
-    if (rhs == nullptr
-        || lhs->args.size() != rhs->args.size()
-        || !TypeEqual(lhs->func, rhs->func)) {
-      return false;
-    }
-
-    for (size_t i = 0; i < lhs->args.size(); ++i) {
-      if (!TypeEqual(lhs->args[i], rhs->args[i])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
-    const TypeDataNode* rhs = other.as<TypeDataNode>();
-    if (rhs == nullptr
-        || lhs->type_vars.size() != rhs->type_vars.size()
-        || !TypeEqual(lhs->header, rhs->header)) {
-      return false;
-    }
-    for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
-      if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
-        return false;
-      }
-    }
-    for (size_t i = 0; i < lhs->constructors.size(); ++i) {
-      if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  // Expr equal checking.
-  bool NDArrayEqual(const runtime::NDArray& lhs,
-                    const runtime::NDArray& rhs) {
-    if (lhs.defined() != rhs.defined()) {
-      return false;
-    } else if (lhs.same_as(rhs)) {
-      return true;
-    } else {
-      auto ldt = lhs->dtype;
-      auto rdt = rhs->dtype;
-      CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
-      CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
-      if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == 
rdt.bits) {
-        size_t data_size = runtime::GetDataSize(*lhs.operator->());
-        return std::memcmp(lhs->data, rhs->data, data_size) == 0;
-      } else {
-        return false;
-      }
-    }
-  }
-  // merge declaration of two variables together.
-  bool MergeVarDecl(const Var& lhs, const Var& rhs) {
-    if (lhs.same_as(rhs)) return true;
-    if (!lhs.defined() || !rhs.defined()) return false;
-    if (!TypeEqual(lhs->type_annotation,
-                   rhs->type_annotation)) return false;
-    CHECK(!equal_map_.count(lhs))
-        << "Duplicated declaration of variable " <<  lhs;
-    equal_map_[lhs] = rhs;
-    return true;
-  }
-
-  bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
-    // This function will only be triggered if we are matching free variables.
-    if (const VarNode* rhs = other.as<VarNode>()) {
-      if (lhs->name_hint() != rhs->name_hint()) return false;
-      if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
-      return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
-    if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
-      // use name equality for global var for now.
-      return lhs->name_hint == rhs->name_hint;
-    }
-    return false;
-  }
-
-  bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
-    if (const TupleNode* rhs = other.as<TupleNode>()) {
-      if (lhs->fields.size() != rhs->fields.size()) return false;
-      for (size_t i = 0; i < lhs->fields.size(); ++i) {
-        if (!ExprEqual(lhs->fields[i], rhs->fields[i])) return false;
-      }
-      return true;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const FunctionNode* lhs, const Expr& other) final {
-    if (const FunctionNode* rhs = other.as<FunctionNode>()) {
-      if (lhs->params.size() != rhs->params.size()) return false;
-      if (lhs->type_params.size() != rhs->type_params.size()) return false;
-      // map type parameter to be the same
-      for (size_t i = 0; i < lhs->type_params.size(); ++i) {
-        if (lhs->type_params[i]->kind != rhs->type_params[i]->kind) return 
false;
-        equal_map_[lhs->type_params[i]] = rhs->type_params[i];
-      }
-      // check parameter type annotations
-      for (size_t i = 0; i < lhs->params.size(); ++i) {
-        if (!MergeVarDecl(lhs->params[i], rhs->params[i])) return false;
-      }
-      // check return types.
-      if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
-      if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
-      return ExprEqual(lhs->body, rhs->body);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const CallNode* lhs, const Expr& other) final {
-    if (const CallNode* rhs = other.as<CallNode>()) {
-      if (!ExprEqual(lhs->op, rhs->op)) return false;
-      if (lhs->args.size() != rhs->args.size()) return false;
-      // skip type_args check for primitive ops.
-      bool is_primitive = IsPrimitiveOp(lhs->op);
-      if (!is_primitive) {
-        if (lhs->type_args.size() != rhs->type_args.size()) {
-          return false;
-        }
-      }
-      for (size_t i = 0; i < lhs->args.size(); ++i) {
-        if (!ExprEqual(lhs->args[i], rhs->args[i])) {
-          return false;
-        }
-      }
-
-      if (!is_primitive) {
-        for (size_t i = 0; i < lhs->type_args.size(); ++i) {
-          if (!TypeEqual(lhs->type_args[i], rhs->type_args[i])) return false;
-        }
-      }
-      return AttrEqual(lhs->attrs, rhs->attrs);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const LetNode* lhs, const Expr& other) final {
-    if (const LetNode* rhs = other.as<LetNode>()) {
-      if (!MergeVarDecl(lhs->var, rhs->var)) return false;
-      if (!ExprEqual(lhs->value, rhs->value)) return false;
-      return ExprEqual(lhs->body, rhs->body);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const IfNode* lhs, const Expr& other) final {
-    if (const IfNode* rhs = other.as<IfNode>()) {
-      return ExprEqual(lhs->cond, rhs->cond) &&
-          ExprEqual(lhs->true_branch, rhs->true_branch) &&
-          ExprEqual(lhs->false_branch, rhs->false_branch);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const OpNode* lhs, const Expr& other) final {
-    return lhs == other.get();
-  }
-
-  bool VisitExpr_(const ConstantNode* lhs, const Expr& other) final {
-    if (const ConstantNode* rhs = other.as<ConstantNode>()) {
-      return NDArrayEqual(lhs->data, rhs->data);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const TupleGetItemNode* lhs, const Expr& other) final {
-    if (const TupleGetItemNode* rhs = other.as<TupleGetItemNode>()) {
-      return ExprEqual(lhs->tuple, rhs->tuple) && lhs->index == rhs->index;
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const RefCreateNode* lhs, const Expr& other) final {
-    if (const RefCreateNode* rhs = other.as<RefCreateNode>()) {
-      return ExprEqual(lhs->value, rhs->value);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const RefReadNode* lhs, const Expr& other) final {
-    if (const RefReadNode* rhs = other.as<RefReadNode>()) {
-      return ExprEqual(lhs->ref, rhs->ref);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const RefWriteNode* lhs, const Expr& other) final {
-    if (const RefWriteNode* rhs = other.as<RefWriteNode>()) {
-      return ExprEqual(lhs->ref, rhs->ref) && ExprEqual(lhs->value, 
rhs->value);
-    } else {
-      return false;
-    }
-  }
-
-  bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
-    if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
-      return lhs->name_hint == rhs->name_hint;
-    }
-    return false;
-  }
-
-  bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
-    return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs);
-  }
-
-  bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
-    return Compare(VisitPattern(lhs, rhs), lhs, rhs);
-  }
-
-  bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) 
final {
-    return other.as<PatternWildcardNode>();
-  }
-
-  bool VisitPattern_(const PatternVarNode* lhs, const Pattern& other) final {
-    if (const auto* rhs = other.as<PatternVarNode>()) {
-      return MergeVarDecl(lhs->var, rhs->var);
-    }
-    return false;
-  }
-
-  bool VisitPattern_(const PatternConstructorNode* lhs, const Pattern& other) 
final {
-    const auto* rhs = other.as<PatternConstructorNode>();
-    if (rhs == nullptr
-        || !ExprEqual(lhs->constructor, rhs->constructor)
-        || lhs->patterns.size() != rhs->patterns.size()) {
-      return false;
-    }
-
-    for (size_t i = 0; i < lhs->patterns.size(); i++) {
-      if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
-    const auto* rhs = other.as<PatternTupleNode>();
-    if (rhs == nullptr
-        || lhs->patterns.size() != rhs->patterns.size()) {
-      return false;
-    }
-
-    for (size_t i = 0; i < lhs->patterns.size(); i++) {
-      if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
-  bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
-    const MatchNode* rhs = other.as<MatchNode>();
-
-    if (rhs == nullptr
-        || !ExprEqual(lhs->data, rhs->data)
-        || lhs->clauses.size() != rhs->clauses.size()
-        || lhs->complete != rhs->complete) {
-      return false;
-    }
-
-    for (size_t i = 0; i < lhs->clauses.size(); ++i) {
-      if (!ClauseEqual(lhs->clauses[i], rhs->clauses[i])) {
-        return false;
-      }
-    }
-    return true;
-  }
-
- private:
-  // whether to map open terms.
-  bool map_free_var_;
-  // if in assert mode, must return true, and will throw error otherwise.
-  bool assert_mode_;
-  // renaming of NodeRef to indicate two nodes equals to each other
-  std::unordered_map<ObjectRef, ObjectRef, ObjectHash, ObjectEqual> equal_map_;
-};
-
-bool AlphaEqual(const Type& lhs, const Type& rhs) {
-  return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
-}
-
-bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
-  return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
-}
-
-TVM_REGISTER_GLOBAL("relay.analysis._alpha_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
-  return AlphaEqualHandler(false, false).Equal(a, b);
-});
-
-TVM_REGISTER_GLOBAL("ir.type_alpha_equal")
-.set_body_typed([](Type a, Type b) {
-  return AlphaEqual(a, b);
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._assert_alpha_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
-  bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
-  CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are 
not alpha equal";
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._graph_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
-  return AlphaEqualHandler(true, false).Equal(a, b);
-});
-
-TVM_REGISTER_GLOBAL("relay.analysis._assert_graph_equal")
-.set_body_typed([](ObjectRef a, ObjectRef b) {
-  bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
-  CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are 
not graph equal";
-});
-
-}  // namespace relay
-}  // namespace tvm
diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc
index c39df9d..650403c 100644
--- a/src/relay/analysis/type_solver.cc
+++ b/src/relay/analysis/type_solver.cc
@@ -21,6 +21,7 @@
  * \file type_solver.cc
  * \brief Type solver implementations.
  */
+#include <tvm/node/structural_equal.h>
 #include <tvm/ir/type_functor.h>
 #include <tvm/tir/op.h>
 #include <string>
@@ -151,11 +152,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const 
Type&, const Type&)> {
     return rc.Check(t);
   }
 
-  // default: unify only if alpha-equal
+  // default: unify only if structural-equal
   Type VisitTypeDefault_(const Object* op, const Type& tn) final {
     ObjectRef nr = GetRef<ObjectRef>(op);
     Type t1 = GetRef<Type>(nr.as<tvm::relay::TypeNode>());
-    if (!AlphaEqual(t1, tn)) {
+    if (!tvm::StructuralEqual()(t1, tn)) {
       return Type(nullptr);
     }
     return t1;
@@ -216,7 +217,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const 
Type&, const Type&)> {
     auto tt1 = GetRef<TensorType>(op);
     auto tt2 = GetRef<TensorType>(tt_node);
 
-    if (AlphaEqual(tt1, tt2)) {
+    if (tvm::StructuralEqual()(tt1, tt2)) {
       return std::move(tt1);
     }
 
diff --git a/src/relay/backend/compile_engine.h 
b/src/relay/backend/compile_engine.h
index 098211e..eec2bd3 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -25,6 +25,7 @@
 #ifndef TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
 #define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_
 
+#include <tvm/node/structural_equal.h>
 #include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/relay/analysis.h>
@@ -268,7 +269,7 @@ inline bool CCacheKeyNode::Equal(
     const CCacheKeyNode* other) const {
   if (Hash() != other->Hash()) return false;
   return this->target->str() == other->target->str() &&
-      AlphaEqual(this->source_func, other->source_func);
+      tvm::StructuralEqual()(this->source_func, other->source_func);
 }
 
 }  // namespace relay
diff --git a/src/relay/backend/vm/lambda_lift.cc 
b/src/relay/backend/vm/lambda_lift.cc
index 7e7622c..398760f 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -22,6 +22,7 @@
  * \brief Lift all local functions into global functions.
  */
 
+#include <tvm/node/structural_equal.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/support/logging.h>
@@ -161,7 +162,8 @@ class LambdaLifter : public ExprMutator {
 
     if (module_->ContainGlobalVar(name)) {
       const auto existing_func = module_->Lookup(name);
-      CHECK(AlphaEqual(lifted_func, existing_func)) << "lifted function hash 
collision";
+      CHECK(tvm::StructuralEqual()(lifted_func, existing_func))
+        << "lifted function hash collision";
       // If an identical function already exists, use its global var.
       global = module_->GetGlobalVar(name);
     } else {
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 3d03b4a..87b4602 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2142,7 +2142,12 @@ Expr MakeSplit(Expr data,
 TVM_REGISTER_GLOBAL("relay.op._make.split")
 .set_body([](const TVMArgs& args, TVMRetValue* rv) {
     if (args.type_codes[1] == kDLInt) {
-      *rv = MakeSplit(args[0], tir::make_const(DataType::Int(64), 
int64_t(args[1])), args[2]);
+      // Note: we change it from Int(64) to Int(32) for now as
+      // combine_parallel_dense will transform the graph with Int(32).
+      // More investigation is needed to check which one we should use.
+      *rv = MakeSplit(args[0],
+                      tir::make_const(DataType::Int(32), 
static_cast<int>(args[1])),
+                      args[2]);
     } else {
       *rv = MakeSplit(args[0], args[1], args[2]);
     }
diff --git a/src/relay/transforms/lazy_gradient_init.cc 
b/src/relay/transforms/lazy_gradient_init.cc
index ba6ca05..e6248f1 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -59,6 +59,7 @@
  * Thus, it is necessary to wrap this outer function so that the input/output 
types remain the same
  */
 
+#include <tvm/node/structural_equal.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/ir/type_functor.h>
@@ -93,7 +94,7 @@ class InputVisitor: public ExprFunctor<Expr(const Expr&, 
const Type&)> {
   Expr WrapExpr(const Expr expr, const Type& type) {
     if (type.as<TensorTypeNode>()) {
       return Call(module_->GetConstructor("GradCell", "Raw"),
-                          {expr}, Attrs(), {type});
+                  {expr}, Attrs(), {type});
     } else if (auto* type_anno = type.as<TupleTypeNode>()) {
       tvm::Array<Expr> fields;
       for (size_t i = 0; i < type_anno->fields.size(); i++) {
@@ -185,7 +186,7 @@ class LazyGradientInitializer: public ExprMutator, public 
TypeMutator {
 
   Expr VisitExpr_(const ConstantNode* op) final {
     return Call(module_->GetConstructor("GradCell", "Raw"),
-                          {GetRef<Constant>(op)}, Attrs(), 
{op->checked_type()});
+                {GetRef<Constant>(op)}, Attrs(), {op->checked_type()});
   }
 
   Expr VisitExpr_(const CallNode* call_node) final {
@@ -207,26 +208,25 @@ class LazyGradientInitializer: public ExprMutator, public 
TypeMutator {
         // call appropriate GradCell constructor
         std::string constructor_name = op_expr == Op::Get("ones") ? "One" : 
"Zero";
         return Call(module_->GetConstructor("GradCell", constructor_name),
-                              {func}, Attrs(), {call_node->checked_type()});
+                    {func}, Attrs(), {call_node->checked_type()});
       }
 
       if (op_expr == Op::Get("ones_like") || op_expr == Op::Get("zeros_like")) 
{
         // ones_like and zeros_like need TensorType input
         Expr result = CallPrimitiveOp(call_node);
         // fn() -> T, function returns result of operation
-        Expr func = Function({}, result,
-                              {call_node->checked_type()}, Array<TypeVar>());
+        Expr func = Function({}, result, {call_node->checked_type()}, 
Array<TypeVar>());
         // call appropriate GradCell constructor
         std::string constructor_name = op_expr == Op::Get("ones_like") ? "One" 
: "Zero";
         return Call(module_->GetConstructor("GradCell", "One"),
-                              {func}, Attrs(), {call_node->checked_type()});
+                    {func}, Attrs(), {call_node->checked_type()});
       }
 
       // handle all other ops
       Expr result = CallPrimitiveOp(call_node);
       // wrap result with Raw constructor
       return Call(module_->GetConstructor("GradCell", "Raw"), {result},
-                            Attrs(), {call_node->checked_type()});
+                  Attrs(), {call_node->checked_type()});
     }
     // not an op
     return ExprMutator::VisitExpr_(call_node);
@@ -253,10 +253,11 @@ class LazyGradientInitializer: public ExprMutator, public 
TypeMutator {
   Expr CallGradCellFunction(const CallNode* call_node, GlobalVar 
overloaded_op) {
     // can only use overloaded functions if 2 arguments of same type
     if (call_node->args.size() != 2 ||
-          !AlphaEqual(call_node->args[0]->checked_type(), 
call_node->args[1]->checked_type())) {
+        !tvm::StructuralEqual()(call_node->args[0]->checked_type(),
+                                call_node->args[1]->checked_type())) {
       Expr result = CallPrimitiveOp(call_node);
       return Call(module_->GetConstructor("GradCell", "Raw"), {result},
-                            Attrs(), {call_node->checked_type()});
+                  Attrs(), {call_node->checked_type()});
     }
 
     tvm::Array<Expr> args;
@@ -266,8 +267,7 @@ class LazyGradientInitializer: public ExprMutator, public 
TypeMutator {
                               Var("rhs", paramType)};
     // use primitive op in this case
     Expr callOp = Call(call_node->op, {params[0], params[1]});
-    Expr func = Function(params, callOp, paramType,
-                                            Array<TypeVar>());
+    Expr func = Function(params, callOp, paramType, Array<TypeVar>());
 
     // pass "fallback" function and tensors as arguments
     args.push_back(func);
diff --git a/src/relay/transforms/pattern_util.h 
b/src/relay/transforms/pattern_util.h
index e86fcdc..8ce42a2 100644
--- a/src/relay/transforms/pattern_util.h
+++ b/src/relay/transforms/pattern_util.h
@@ -27,6 +27,7 @@
 #define TVM_RELAY_TRANSFORMS_PATTERN_UTIL_H_
 
 #include <builtin_fp16.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/tir/data_layout.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/expr.h>
@@ -300,7 +301,7 @@ inline bool IsEqualScalar(const Expr& a, const Expr& b) {
   if (!constant_a || !constant_b || !constant_a->is_scalar() || 
!constant_b->is_scalar()) {
     return false;
   }
-  return AlphaEqual(a, b);
+  return tvm::StructuralEqual()(a, b);
 }
 
 inline Expr GetField(Expr t, size_t i) {
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 1039a1b..e6c8392 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -353,7 +353,7 @@ Function UnCPS(const Function& f) {
   auto answer_type = new_type_params.back();
   new_type_params.pop_back();
   // TODO(@M.K.): make alphaequal work on free term
-  // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
+  // CHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, 
answer_type)));
   auto x = Var("x", new_ret_type);
   auto cont = Function({x}, x, new_ret_type, {}, {});
   tvm::Array<Expr> args;
diff --git a/tests/cpp/relay_pass_alpha_equal.cc 
b/tests/cpp/relay_pass_alpha_equal.cc
deleted file mode 100644
index 0207fca..0000000
--- a/tests/cpp/relay_pass_alpha_equal.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <gtest/gtest.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-
-using namespace tvm;
-
-class TestAlphaEquals {
-  runtime::PackedFunc *_packed_func;
- public:
-  TestAlphaEquals(const char* func_name) {
-    _packed_func = new runtime::PackedFunc();
-    TVMFuncGetGlobal(func_name, 
reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  void UpdatePackedFunc(const char* func_name) {
-    TVMFuncGetGlobal(func_name, 
reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  bool operator()(ObjectRef input_1, ObjectRef input_2) {
-    TVMRetValue rv;
-    std::vector<TVMValue> values(2);
-    std::vector<int> codes(2);
-    runtime::TVMArgsSetter setter(values.data(), codes.data());
-    setter(0, input_1);
-    setter(1, input_2);
-    _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
-    return bool(rv);
-  };
-
-};
-
-TEST(Relay, AlphaTestEmptyTypeNodes) {
-  auto x = TypeVar("x", kTypeData);
-  auto y = TypeVar();
-  EXPECT_FALSE(relay::AlphaEqual(x, y));
-
-  TestAlphaEquals test_equals("relay._make._alpha_equal");
-  EXPECT_FALSE(test_equals(x, y));
-}
-
-int main(int argc, char ** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  testing::FLAGS_gtest_death_test_style = "threadsafe";
-  return RUN_ALL_TESTS();
-}
diff --git a/tests/cpp/relay_pass_type_infer_test.cc 
b/tests/cpp/relay_pass_type_infer_test.cc
index f951a8f..3c41691 100644
--- a/tests/cpp/relay_pass_type_infer_test.cc
+++ b/tests/cpp/relay_pass_type_infer_test.cc
@@ -18,6 +18,7 @@
  */
 
 #include <gtest/gtest.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/te/operation.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/type.h>
@@ -38,7 +39,7 @@ TEST(Relay, SelfReference) {
   auto type_fx = mod->Lookup("main");
 
   auto expected = relay::FuncType(tvm::Array<relay::Type>{ tensor_type }, 
tensor_type, {}, {});
-  CHECK(relay::AlphaEqual(type_fx->checked_type(), expected));
+  CHECK(tvm::StructuralEqual()(type_fx->checked_type(), expected));
 }
 
 int main(int argc, char ** argv) {
diff --git a/tests/cpp/relay_transform_sequential.cc 
b/tests/cpp/relay_transform_sequential.cc
index 756468c..d974f02 100644
--- a/tests/cpp/relay_transform_sequential.cc
+++ b/tests/cpp/relay_transform_sequential.cc
@@ -19,6 +19,7 @@
 
 #include <gtest/gtest.h>
 #include <topi/generic/injective.h>
+#include <tvm/node/structural_equal.h>
 #include <tvm/driver/driver_api.h>
 #include <tvm/relay/expr.h>
 #include <tvm/ir/module.h>
@@ -102,7 +103,7 @@ TEST(Relay, Sequential) {
   auto mod1 = IRModule::FromExpr(expected_func);
   mod1 = relay::transform::InferType()(mod1);
   auto expected = mod1->Lookup("main");
-  CHECK(relay::AlphaEqual(f, expected));
+  CHECK(tvm::StructuralEqual()(f, expected));
 }
 
 int main(int argc, char** argv) {
diff --git a/tests/python/frontend/caffe2/test_graph.py 
b/tests/python/frontend/caffe2/test_graph.py
index 35914ec..d64b133 100644
--- a/tests/python/frontend/caffe2/test_graph.py
+++ b/tests/python/frontend/caffe2/test_graph.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test graph equality of caffe2 models."""
+import tvm
 from tvm import relay
 from tvm.relay import transform
 from model_zoo import c2_squeezenet, relay_squeezenet
@@ -23,7 +24,7 @@ from model_zoo import c2_squeezenet, relay_squeezenet
 def compare_graph(lhs_mod, rhs_mod):
     lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
-    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+    assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
 
 
 def test_squeeze_net():
diff --git a/tests/python/frontend/mxnet/test_graph.py 
b/tests/python/frontend/mxnet/test_graph.py
index 0008799..b7c01a5 100644
--- a/tests/python/frontend/mxnet/test_graph.py
+++ b/tests/python/frontend/mxnet/test_graph.py
@@ -25,7 +25,7 @@ import model_zoo
 def compare_graph(lhs_mod, rhs_mod):
     lhs_mod = transform.InferType()(lhs_mod)
     rhs_mod = transform.InferType()(rhs_mod)
-    assert relay.analysis.alpha_equal(lhs_mod["main"], rhs_mod["main"])
+    assert tvm.ir.structural_equal(lhs_mod["main"], rhs_mod["main"])
 
 def test_mlp():
     shape = {"data": (1, 1, 28, 28)}
diff --git a/tests/python/relay/test_analysis_extract_fused_functions.py 
b/tests/python/relay/test_analysis_extract_fused_functions.py
index 1a70ef1..dab481c 100644
--- a/tests/python/relay/test_analysis_extract_fused_functions.py
+++ b/tests/python/relay/test_analysis_extract_fused_functions.py
@@ -77,7 +77,7 @@ def test_extract_identity():
 
     mod["main"] = mod["main"].with_attr(
         "Primitive", tvm.tir.IntImm("int32", 1))
-    relay.analysis.assert_graph_equal(list(items.values())[0], mod["main"])
+    tvm.ir.assert_structural_equal(list(items.values())[0], mod["main"], map_free_vars=True)
 
 
 def test_extract_conv_net():
diff --git a/tests/python/relay/test_annotate_target.py 
b/tests/python/relay/test_annotate_target.py
index f4e602a..12a15dc 100644
--- a/tests/python/relay/test_annotate_target.py
+++ b/tests/python/relay/test_annotate_target.py
@@ -136,7 +136,7 @@ def test_extern_dnnl():
         mod = annotated(dtype, ishape, w1shape)
         mod = transform.AnnotateTarget("dnnl")(mod)
         ref_mod = expected(dtype, ishape, w1shape)
-        assert relay.analysis.alpha_equal(mod, ref_mod)
+        assert tvm.ir.structural_equal(mod, ref_mod)
 
     def test_run():
         if not tvm.get_global_func("relay.ext.dnnl", True):
diff --git a/tests/python/relay/test_call_graph.py 
b/tests/python/relay/test_call_graph.py
index 849f015..0af55d2 100644
--- a/tests/python/relay/test_call_graph.py
+++ b/tests/python/relay/test_call_graph.py
@@ -27,7 +27,7 @@ def test_callgraph_construct():
     mod["g1"] = relay.Function([x, y], x + y)
     call_graph = relay.analysis.CallGraph(mod)
     assert "g1" in str(call_graph)
-    assert relay.alpha_equal(mod, call_graph.module)
+    assert tvm.ir.structural_equal(mod, call_graph.module)
 
 
 def test_print_element():
diff --git a/tests/python/relay/test_ir_bind.py 
b/tests/python/relay/test_ir_bind.py
index 45474b6..8ba4644 100644
--- a/tests/python/relay/test_ir_bind.py
+++ b/tests/python/relay/test_ir_bind.py
@@ -29,11 +29,11 @@ def test_bind_params():
     fexpected =relay.Function(
         [y],
         relay.add(relay.const(1, "float32"),  y))
-    assert relay.analysis.alpha_equal(fbinded, fexpected)
+    assert tvm.ir.structural_equal(fbinded, fexpected)
 
     zbinded = relay.bind(z, {y: x})
     zexpected = relay.add(x, x)
-    assert relay.analysis.alpha_equal(zbinded, zexpected)
+    assert tvm.ir.structural_equal(zbinded, zexpected)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_ir_nodes.py 
b/tests/python/relay/test_ir_nodes.py
index 968a3bb..6d4a685 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -21,13 +21,12 @@ from tvm import te
 from tvm import relay
 from tvm.tir.expr import *
 from tvm.relay import op
-from tvm.relay.analysis import graph_equal
 import numpy as np
 
 def check_json_roundtrip(node):
     json_str = tvm.ir.save_json(node)
     back = tvm.ir.load_json(json_str)
-    assert graph_equal(back, node)
+    assert tvm.ir.structural_equal(back, node, map_free_vars=True)
 
 
 # Span
diff --git a/tests/python/relay/test_ir_structural_equal_hash.py 
b/tests/python/relay/test_ir_structural_equal_hash.py
index cf626d7..5295e17 100644
--- a/tests/python/relay/test_ir_structural_equal_hash.py
+++ b/tests/python/relay/test_ir_structural_equal_hash.py
@@ -107,7 +107,7 @@ def test_func_type_sequal():
     ft = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
                          tvm.runtime.convert([tp1, tp3]),
                          tvm.runtime.convert([tr1]))
-    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp1,
+    translate_vars = relay.FuncType(tvm.runtime.convert([t1, t2]), tp2,
                          tvm.runtime.convert([tp2, tp4]),
                          tvm.runtime.convert([tr2]))
     assert ft == translate_vars
diff --git a/tests/python/relay/test_ir_text_printer.py 
b/tests/python/relay/test_ir_text_printer.py
index 49518a8..61dbca3 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -20,7 +20,7 @@ from tvm import relay
 import tvm.relay.testing
 import numpy as np
 from tvm.relay import Expr
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal, 
assert_graph_equal, free_vars
+from tvm.relay.analysis import free_vars
 
 do_print = [False]
 
@@ -32,9 +32,9 @@ def astext(p, unify_free_vars=False):
         return txt
     x = relay.fromtext(txt)
     if unify_free_vars:
-        assert_graph_equal(x, p)
+        tvm.ir.assert_structural_equal(x, p, map_free_vars=True)
     else:
-        assert_alpha_equal(x, p)
+        tvm.ir.assert_structural_equal(x, p)
     return txt
 
 def show(text):
diff --git a/tests/python/relay/test_op_level10.py 
b/tests/python/relay/test_op_level10.py
index 953760c..30e2506 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -99,7 +99,7 @@ def test_checkpoint_alpha_equal():
         """
     )
 
-    relay.analysis.assert_alpha_equal(df, df_parsed)
+    tvm.ir.assert_structural_equal(df, df_parsed)
 
 def test_checkpoint_alpha_equal_tuple():
     xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i 
in range(4)]
@@ -146,7 +146,7 @@ def test_checkpoint_alpha_equal_tuple():
         """
     )
 
-    relay.analysis.assert_alpha_equal(df, df_parsed)
+    tvm.ir.assert_structural_equal(df, df_parsed)
 
 def test_collapse_sum_like():
     shape = (3, 4, 5, 6)
diff --git a/tests/python/relay/test_pass_alter_op_layout.py 
b/tests/python/relay/test_pass_alter_op_layout.py
index eabe758..a30492f 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -66,7 +66,7 @@ def test_alter_op():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_return_none():
@@ -88,7 +88,7 @@ def test_alter_return_none():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(before(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])
 
 
@@ -151,7 +151,7 @@ def test_alter_layout():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_dual_path():
@@ -214,7 +214,7 @@ def test_alter_layout_dual_path():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 def test_alter_layout_resnet():
     """Test alternating the layout of a residual block
@@ -271,7 +271,7 @@ def test_alter_layout_resnet():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_broadcast_op():
@@ -318,7 +318,7 @@ def test_alter_layout_broadcast_op():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_broadcast_scalar_op():
@@ -381,7 +381,7 @@ def test_alter_layout_broadcast_scalar_op():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_scalar():
@@ -424,7 +424,7 @@ def test_alter_layout_scalar():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_concatenate():
@@ -478,7 +478,7 @@ def test_alter_layout_concatenate():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nchw(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
     # NHWC layout transformation.
     def before_nhwc():
@@ -524,7 +524,7 @@ def test_alter_layout_concatenate():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nhwc(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_nchw_upsamping_op():
@@ -561,7 +561,7 @@ def test_alter_layout_nchw_upsamping_op():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_strided_slice():
@@ -597,7 +597,7 @@ def test_alter_layout_strided_slice():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 def test_alter_layout_depthwise_conv2d():
     """Test depthwise_conv2d operator"""
@@ -632,7 +632,7 @@ def test_alter_layout_depthwise_conv2d():
                              transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert(analysis.alpha_equal(a, b))
+    assert(tvm.ir.structural_equal(a, b))
 
 def test_alter_layout_prelu():
     """Test PRelu operator"""
@@ -672,7 +672,7 @@ def test_alter_layout_prelu():
         a = run_opt_pass(a, [transform.CanonicalizeOps(), 
transform.AlterOpLayout()])
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert(analysis.alpha_equal(a, b))
+    assert(tvm.ir.structural_equal(a, b))
 
 
 def test_alter_layout_pad():
@@ -715,7 +715,7 @@ def test_alter_layout_pad():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nchw(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
     # Check NHWC conversion.
     def before_nhwc():
@@ -749,7 +749,7 @@ def test_alter_layout_pad():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nhwc(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
     # Check that conversion does not happen when padding along split axis.
     def before():
@@ -782,7 +782,7 @@ def test_alter_layout_pad():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_pool():
@@ -825,7 +825,7 @@ def test_alter_layout_pool():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nchw(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
     # Check NHWC conversion.
     def before_nhwc():
@@ -859,7 +859,7 @@ def test_alter_layout_pool():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nhwc(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_alter_layout_sum():
@@ -902,7 +902,7 @@ def test_alter_layout_sum():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nchw(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
     # Check NHWC conversion.
     def before_nhwc():
@@ -937,7 +937,7 @@ def test_alter_layout_sum():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nhwc(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 # TODO(@anijain2305, @icemelon9): We should fix this. This doesn't seem to be 
the
@@ -999,7 +999,7 @@ def test_alter_layout_nhwc_nchw_arm():
         a = run_opt_pass(a, transform.AlterOpLayout())
         b = run_opt_pass(expected_nhwc(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 def test_alter_op_with_global_var():
     """Test directly replacing an operator with a new one"""
@@ -1041,7 +1041,7 @@ def test_alter_op_with_global_var():
         a = transform.AlterOpLayout()(a)
         b = transform.InferType()(expected())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b, map_free_vars=True), "Actual = \n" + 
str(a)
 
 if __name__ == "__main__":
     test_alter_op()
diff --git a/tests/python/relay/test_pass_annotation.py 
b/tests/python/relay/test_pass_annotation.py
index 3e7d916..ea92546 100644
--- a/tests/python/relay/test_pass_annotation.py
+++ b/tests/python/relay/test_pass_annotation.py
@@ -64,7 +64,7 @@ def test_redundant_annotation():
 
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 
 
 def test_annotate_expr():
@@ -91,7 +91,7 @@ def test_annotate_expr():
 
     annotated_expr = annotated()
     expected_expr = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_expr, expected_expr)
+    assert tvm.ir.structural_equal(annotated_expr, expected_expr)
 
 
 def test_annotate_all():
@@ -120,7 +120,7 @@ def test_annotate_all():
 
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 
 
 def test_annotate_none():
@@ -146,13 +146,13 @@ def test_annotate_none():
 
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 
 
 def check_annotated_graph(annotated_func, expected_func):
     annotated_func = run_opt_pass(annotated_func, transform.InferType())
     expected_func = run_opt_pass(expected_func, transform.InferType())
-    assert relay.analysis.alpha_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 
 
 def test_conv_network():
@@ -596,7 +596,7 @@ def test_tuple_get_item():
 
     annotated_func = annotated()
     expected_func = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(annotated_func, expected_func)
+    assert tvm.ir.structural_equal(annotated_func, expected_func)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_canonicalize_cast.py 
b/tests/python/relay/test_pass_canonicalize_cast.py
index e9ab67f..7b6617a 100644
--- a/tests/python/relay/test_pass_canonicalize_cast.py
+++ b/tests/python/relay/test_pass_canonicalize_cast.py
@@ -64,7 +64,7 @@ def test_canonicalize_cast():
         mod[gv] = y_expected
         mod = _transform.InferType()(mod)
         y_expected = mod["expected"]
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected)
 
     check((1, 16, 7, 7))
 
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py 
b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index ec9bcd9..345f068 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -72,7 +72,7 @@ def test_combine_parallel_conv2d():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, 
channels3, channels4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
 
     check((1, 4, 16, 16), 4, 4, 4, 4)
     check((1, 4, 16, 16), 4, 8, 4, 7)
@@ -118,7 +118,7 @@ def test_combine_parallel_conv2d_scale_relu():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, 
channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
 
     check((1, 4, 16, 16), 4, 8)
 
@@ -157,7 +157,7 @@ def test_combine_parallel_conv2d_scale():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
 
     check((1, 4, 16, 16), 4, 8)
 
@@ -193,7 +193,7 @@ def test_combine_parallel_conv2d_multiple_blocks():
                          transform.CombineParallelConv2D(min_num_branches=2))
         y_expected = expected(x, w, out_c, repeat)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        assert tvm.ir.structural_equal(y, y_expected, map_free_vars=True)
 
     check((1, 4, 16, 16), 4)
 
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py 
b/tests/python/relay/test_pass_combine_parallel_dense.py
index 84d8211..f0f2e18 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -75,7 +75,7 @@ def test_combine_parallel_dense():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, w3, w4)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
     check(3, 5, 4)
     check(100, 200, 300)
@@ -127,7 +127,7 @@ def test_combine_parallel_dense_biasadd():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, is_2d_bias)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
     check(3, 5, 4, False)
     check(100, 200, 300, False)
@@ -184,7 +184,7 @@ def test_combine_parallel_dense_biasadd_scale_reshape():
                          transform.CombineParallelDense(min_num_branches=2))
         y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape)
         y_expected = run_opt_pass(y_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y, y_expected)
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
 
     check(3, 5, 4, 0.5, 0.25, (1, 1, 15))
     check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
diff --git a/tests/python/relay/test_pass_convert_op_layout.py 
b/tests/python/relay/test_pass_convert_op_layout.py
index 9e8f662..c783971 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -52,7 +52,7 @@ def test_no_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_conv_convert_layout():
@@ -87,7 +87,7 @@ def test_conv_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_conv_bias_pool_convert_layout():
@@ -132,7 +132,7 @@ def test_conv_bias_pool_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_conv_concat_convert_layout():
@@ -180,7 +180,7 @@ def test_conv_concat_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_dual_path_convert_layout():
@@ -235,7 +235,7 @@ def test_dual_path_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_bn_convert_layout():
@@ -315,7 +315,7 @@ def test_resnet_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_scalar_convert_layout():
@@ -347,7 +347,7 @@ def test_scalar_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_conv_bn_convert_layout():
@@ -395,7 +395,7 @@ def test_conv_bn_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_qnn_conv_requantize_convert_layout():
@@ -451,7 +451,7 @@ def test_qnn_conv_requantize_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_qnn_conv_concat_convert_layout():
@@ -529,7 +529,7 @@ def test_qnn_conv_concat_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_qnn_conv_add_convert_layout():
@@ -609,7 +609,7 @@ def test_qnn_conv_add_convert_layout():
     a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
     b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py 
b/tests/python/relay/test_pass_dead_code_elimination.py
index 3a0bf1f..60dfa62 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -18,7 +18,7 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import Function, transform
-from tvm.relay.analysis import alpha_equal, graph_equal, free_vars, 
assert_alpha_equal
+from tvm.relay.analysis import free_vars
 from tvm.relay.op import log, add, equal, subtract
 from tvm.relay.testing import inception_v3
 
@@ -69,7 +69,7 @@ def test_used_let():
 def test_inline():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
     orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
-    assert_alpha_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+    tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), 
Function([e.d], e.d))
 
 
 def test_chain_unused_let():
@@ -105,7 +105,7 @@ def test_recursion():
     orig = use_f(lambda f: relay.Call(f, [relay.const(2), 
relay.const(10000.0)]))
     dced = run_opt_pass(orig, transform.DeadCodeElimination())
     orig = run_opt_pass(orig, transform.InferType())
-    assert_alpha_equal(dced, orig)
+    tvm.ir.assert_structural_equal(dced, orig)
 
 def test_recursion_dead():
     x = relay.Let(e.a, e.one, e.three)
diff --git a/tests/python/relay/test_pass_eliminate_common_subexpr.py 
b/tests/python/relay/test_pass_eliminate_common_subexpr.py
index dddbef7..89e3b67 100644
--- a/tests/python/relay/test_pass_eliminate_common_subexpr.py
+++ b/tests/python/relay/test_pass_eliminate_common_subexpr.py
@@ -52,7 +52,7 @@ def test_simple():
 
     z = before()
     z = run_opt_pass(z, transform.EliminateCommonSubexpr())
-    assert analysis.alpha_equal(z, expected())
+    assert tvm.ir.structural_equal(z, expected())
 
 
 def test_callback():
@@ -82,7 +82,7 @@ def test_callback():
 
     z = before()
     z = run_opt_pass(z, transform.EliminateCommonSubexpr(fskip))
-    assert analysis.alpha_equal(z, expected())
+    assert tvm.ir.structural_equal(z, expected())
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_eta_expand.py 
b/tests/python/relay/test_pass_eta_expand.py
index ad04e41..84ff54a 100644
--- a/tests/python/relay/test_pass_eta_expand.py
+++ b/tests/python/relay/test_pass_eta_expand.py
@@ -47,7 +47,8 @@ def test_eta_expand_global_var():
             }
         }
     """)
-    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+                                   map_free_vars=True)
 
 
 def test_eta_expand_constructor():
@@ -76,7 +77,8 @@ def test_eta_expand_constructor():
             }
         }
     """)
-    relay.analysis.assert_graph_equal(mod['main'], expected['main'])
+    tvm.ir.assert_structural_equal(mod['main'], expected['main'],
+                                   map_free_vars=True)
 
 
 if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_fold_constant.py 
b/tests/python/relay/test_pass_fold_constant.py
index cc362a2..3ddafd7 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -59,7 +59,7 @@ def test_fold_const():
         with tvm.target.create("cuda"):
             zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_let():
@@ -84,7 +84,7 @@ def test_fold_let():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_tuple():
@@ -106,7 +106,7 @@ def test_fold_tuple():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_concat():
@@ -125,7 +125,7 @@ def test_fold_concat():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_shape_of():
@@ -146,7 +146,7 @@ def test_fold_shape_of():
     for dtype in ["int32", "float32"]:
         zz = run_opt_pass(before(dtype), transform.FoldConstant())
         zexpected = run_opt_pass(expected(dtype), transform.InferType())
-        assert relay.analysis.graph_equal(zz, zexpected)
+        assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_full():
@@ -161,7 +161,7 @@ def test_fold_full():
 
     zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.graph_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_fold_batch_norm():
@@ -202,7 +202,7 @@ def test_fold_batch_norm():
         mod = remove_bn_pass(mod)
 
     expect = run_infer_type(expected())
-    assert relay.analysis.graph_equal(mod["main"], expect)
+    assert tvm.ir.structural_equal(mod["main"], expect)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py 
b/tests/python/relay/test_pass_fold_scale_axis.py
index 4c094fb..bf2a708 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -79,7 +79,7 @@ def test_fold_fwd_simple():
 
         y1_folded = run_opt_pass(y1_folded, transform.InferType())
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 2)
 
@@ -148,7 +148,7 @@ def test_fold_fwd_dual_path():
         weight = relay.var("weight", type_dict["weight"])
         y1_expected = expected(x, weight, in_bias, in_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 3), 3)
 
@@ -177,7 +177,7 @@ def test_fold_fwd_fail():
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-        assert relay.analysis.alpha_equal(y1, y1_folded)
+        assert tvm.ir.structural_equal(y1, y1_folded)
 
     check((2, 11, 10, 4), 4)
 
@@ -205,7 +205,7 @@ def test_fold_fwd_relu_fail():
         y1 = before(x, weight, in_bias, in_scale, channels)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
-        assert relay.analysis.alpha_equal(y1, y1_folded)
+        assert tvm.ir.structural_equal(y1, y1_folded)
 
     in_scale = relay.var("in_scale", shape=(4,))
     check((2, 11, 10, 4), 4, in_scale)
@@ -249,7 +249,7 @@ def test_fold_fwd_negative_scale():
         y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis())
         y1_expected = expected(x, weight, in_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 4)
 
@@ -300,7 +300,7 @@ def test_fold_bwd_simple():
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
 
@@ -359,7 +359,7 @@ def test_fold_bwd_dual_path():
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
 
@@ -431,7 +431,7 @@ def test_fold_bwd_dual_consumer():
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_bias, out_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 4)
 
@@ -480,7 +480,7 @@ def test_fold_bwd_fail():
         y1 = fbefore(x, weight, out_bias, out_scale, channels)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        assert relay.analysis.alpha_equal(y1_folded, y1)
+        assert tvm.ir.structural_equal(y1_folded, y1)
 
     check((4, 4, 10, 10), 4, fail1)
     check((4, 4, 10, 10), 4, fail2)
@@ -505,7 +505,7 @@ def test_fold_bwd_relu_fail():
         y1 = before(x, weight, out_scale, channels)
         y1 = run_opt_pass(y1, transform.InferType())
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
-        assert relay.analysis.alpha_equal(y1, y1_folded)
+        assert tvm.ir.structural_equal(y1, y1_folded)
 
     out_scale = relay.var("in_scale", shape=(4, 1, 1))
     check((4, 4, 10, 10), 4, out_scale)
@@ -547,7 +547,7 @@ def test_fold_bwd_negative_scale():
         y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis())
         y1_expected = expected(x, weight, out_scale, channels)
         y1_expected = run_opt_pass(y1_expected, transform.InferType())
-        assert relay.analysis.alpha_equal(y1_folded, y1_expected)
+        assert tvm.ir.structural_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
 
diff --git a/tests/python/relay/test_pass_fuse_ops.py 
b/tests/python/relay/test_pass_fuse_ops.py
index 108c91b..6b7d297 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -45,7 +45,7 @@ def test_fuse_simple():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     zz = run_opt_pass(z, transform.FuseOps())
     after = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_conv2d_fuse():
@@ -127,7 +127,7 @@ def test_conv2d_fuse():
     z = before(dshape)
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_concatenate():
@@ -167,7 +167,7 @@ def test_concatenate():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     assert not relay.analysis.free_vars(zz)
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_tuple_root():
@@ -204,7 +204,7 @@ def test_tuple_root():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     assert not relay.analysis.free_vars(zz)
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_stop_fusion():
@@ -235,7 +235,7 @@ def test_stop_fusion():
     z = before(dshape)
     zz = run_opt_pass(z, transform.FuseOps())
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_fuse_myia_regression():
@@ -271,7 +271,7 @@ def test_fuse_myia_regression():
     f = before(dshape, dtype)
     zz = run_opt_pass(f, transform.FuseOps())
     after = run_opt_pass(expected(dshape, dtype), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_fuse_tuple_get_elemwise():
@@ -309,7 +309,7 @@ def test_fuse_tuple_get_elemwise():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     assert not relay.analysis.free_vars(zz)
     after = run_opt_pass(expected(dim), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_tuple_get_root():
@@ -346,7 +346,7 @@ def test_tuple_get_root():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     assert not relay.analysis.free_vars(zz)
     after = run_opt_pass(expected(dim), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 fuse0 = relay.transform.FuseOps(fuse_opt_level=0)
@@ -379,7 +379,7 @@ def test_tuple_intermediate():
     m = fuse2(tvm.IRModule.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(x), transform.InferType())
-    assert relay.analysis.alpha_equal(m["main"], after)
+    assert tvm.ir.structural_equal(m["main"], after)
 
 
 def test_tuple_consecutive():
@@ -437,7 +437,7 @@ def test_tuple_consecutive():
     m = fuse2(tvm.IRModule.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(m["main"], after)
+    assert tvm.ir.structural_equal(m["main"], after)
 
 
 def test_inception_like():
@@ -510,7 +510,7 @@ def test_inception_like():
     m = fuse2(tvm.IRModule.from_expr(orig))
     relay.build(m, 'llvm')
     after = run_opt_pass(expected(dshape), transform.InferType())
-    assert relay.analysis.alpha_equal(m["main"], after)
+    assert tvm.ir.structural_equal(m["main"], after)
 
 
 def test_fuse_parallel_injective():
@@ -541,7 +541,7 @@ def test_fuse_parallel_injective():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     assert not relay.analysis.free_vars(zz)
     after = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 
 def test_immutable():
@@ -570,8 +570,8 @@ def test_immutable():
 
     mod = before()
     new_mod = transform.FuseOps(fuse_opt_level=2)(mod)
-    assert relay.analysis.alpha_equal(mod, before())
-    assert relay.analysis.alpha_equal(new_mod, expected())
+    assert tvm.ir.structural_equal(mod, before())
+    assert tvm.ir.structural_equal(new_mod, expected())
 
 
 def test_split():
@@ -619,7 +619,7 @@ def test_fuse_max():
     zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     zz = run_opt_pass(z, transform.FuseOps())
     after = run_opt_pass(expected(), transform.InferType())
-    assert relay.analysis.alpha_equal(zz, after)
+    assert tvm.ir.structural_equal(zz, after)
 
 if __name__ == "__main__":
     test_fuse_simple()
diff --git a/tests/python/relay/test_pass_gradient.py 
b/tests/python/relay/test_pass_gradient.py
index 6f2a125..efd01cb 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -19,7 +19,7 @@ import numpy as np
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.analysis import free_vars, free_type_vars, assert_alpha_equal
+from tvm.relay.analysis import free_vars, free_type_vars
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
 from tvm.relay.prelude import Prelude
@@ -292,7 +292,7 @@ def test_concat():
     func = relay.Function([x], y)
     func = run_infer_type(func)
     back_func = run_infer_type(gradient(func))
-    assert_alpha_equal(back_func.checked_type, relay.FuncType([t], 
relay.TupleType([rt, relay.TupleType([t])])))
+    tvm.ir.assert_structural_equal(back_func.checked_type, relay.FuncType([t], 
relay.TupleType([rt, relay.TupleType([t])])))
     # no value validation as concatenate has dummy gradient right now.
 
 
diff --git a/tests/python/relay/test_pass_inline.py 
b/tests/python/relay/test_pass_inline.py
index f4943ab..0f6d539 100644
--- a/tests/python/relay/test_pass_inline.py
+++ b/tests/python/relay/test_pass_inline.py
@@ -115,7 +115,7 @@ def test_call_chain_inline_leaf():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_call_chain_inline_multiple_levels():
@@ -188,7 +188,7 @@ def test_call_chain_inline_multiple_levels():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_call_chain_inline_multiple_levels_extern_compiler():
@@ -266,7 +266,7 @@ def 
test_call_chain_inline_multiple_levels_extern_compiler():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_recursive_call_with_global():
@@ -321,7 +321,7 @@ def test_recursive_call_with_global():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_recursive_called():
@@ -330,7 +330,7 @@ def test_recursive_called():
     mod["main"] = relay.Function([iarg], sum_up(iarg))
     ref_mod = mod
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
 
 
 def test_recursive_not_called():
@@ -356,7 +356,7 @@ def test_recursive_not_called():
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
     ref_mod = expected()
-    assert relay.analysis.alpha_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
 
 
 def test_recursive_not_called_extern_compiler():
@@ -387,7 +387,7 @@ def test_recursive_not_called_extern_compiler():
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
     ref_mod = expected()
-    assert relay.analysis.alpha_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
 
 
 def test_globalvar_as_call_arg():
@@ -434,7 +434,7 @@ def test_globalvar_as_call_arg():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_globalvar_as_call_arg_extern_compiler():
@@ -500,7 +500,7 @@ def test_globalvar_as_call_arg_extern_compiler():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_inline_globalvar_without_args():
@@ -531,7 +531,7 @@ def test_inline_globalvar_without_args():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_inline_globalvar_without_args_extern_compiler():
@@ -566,7 +566,7 @@ def test_inline_globalvar_without_args_extern_compiler():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_globalvar_called_by_multiple_functions():
@@ -644,7 +644,7 @@ def test_globalvar_called_by_multiple_functions():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_entry_with_inline():
@@ -674,7 +674,7 @@ def test_entry_with_inline():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, get_mod())
+    assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
 
 
 def test_callee_not_inline():
@@ -707,7 +707,7 @@ def test_callee_not_inline():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, get_mod())
+    assert tvm.ir.structural_equal(mod, get_mod(), map_free_vars=True)
 
 
 def test_callee_not_inline_leaf_inline():
@@ -765,7 +765,7 @@ def test_callee_not_inline_leaf_inline():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 def test_callee_not_inline_leaf_inline_extern_compiler():
@@ -830,7 +830,7 @@ def test_callee_not_inline_leaf_inline_extern_compiler():
 
     mod = get_mod()
     mod = relay.transform.Inline()(mod)
-    assert relay.analysis.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
 
 if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_legalize.py 
b/tests/python/relay/test_pass_legalize.py
index 9976eca..1456700 100644
--- a/tests/python/relay/test_pass_legalize.py
+++ b/tests/python/relay/test_pass_legalize.py
@@ -68,7 +68,7 @@ def test_legalize():
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 def test_legalize_none():
     """Test doing nothing by returning 'None' """
@@ -89,7 +89,7 @@ def test_legalize_none():
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(before(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])
 
 def test_legalize_multiple_ops():
@@ -134,7 +134,7 @@ def test_legalize_multiple_ops():
             a = run_opt_pass(a, transform.Legalize())
             b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_legalize_multi_input():
@@ -170,7 +170,7 @@ def test_legalize_multi_input():
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
 
-    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_manager.py 
b/tests/python/relay/test_pass_manager.py
index aed0269..0a6555b 100644
--- a/tests/python/relay/test_pass_manager.py
+++ b/tests/python/relay/test_pass_manager.py
@@ -111,7 +111,7 @@ def get_rand(shape, dtype='float32'):
 def check_func(func, ref_func):
     func = run_infer_type(func)
     ref_func = run_infer_type(ref_func)
-    assert analysis.graph_equal(func, ref_func)
+    assert tvm.ir.structural_equal(func, ref_func)
 
 
 def test_module_pass():
@@ -211,7 +211,7 @@ def test_function_class_pass():
     mod = fpass(mod)
     # wrap in expr
     mod2 = tvm.IRModule.from_expr(f1)
-    assert relay.alpha_equal(mod["main"], mod2["main"])
+    assert tvm.ir.structural_equal(mod["main"], mod2["main"])
 
 
 def test_function_pass():
@@ -496,7 +496,7 @@ def test_sequential_with_scoping():
 
     zz = mod["main"]
     zexpected = run_infer_type(expected())
-    assert analysis.alpha_equal(zz, zexpected)
+    assert tvm.ir.structural_equal(zz, zexpected)
 
 
 def test_print_ir(capfd):
diff --git a/tests/python/relay/test_pass_merge_composite.py 
b/tests/python/relay/test_pass_merge_composite.py
index 72ed3fc..3c70cf2 100644
--- a/tests/python/relay/test_pass_merge_composite.py
+++ b/tests/python/relay/test_pass_merge_composite.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Unit tests for merge composite."""
+import tvm
 from tvm import relay
 from tvm import tir
 from tvm.relay.testing import run_opt_pass
@@ -192,7 +193,7 @@ def test_simple_merge():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(expected(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_branch_merge():
@@ -270,7 +271,7 @@ def test_branch_merge():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(expected(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_reuse_call_merge():
@@ -329,7 +330,7 @@ def test_reuse_call_merge():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(expected(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_multiple_patterns():
@@ -422,7 +423,7 @@ def test_multiple_patterns():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(expected(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_merge_order():
@@ -494,7 +495,7 @@ def test_merge_order():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after_A_priority("A"), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
     # check B highest priority
     pattern_table = [
@@ -505,7 +506,7 @@ def test_merge_order():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after_A_priority("B"), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
     # check C highest priority
     pattern_table = [
@@ -516,7 +517,7 @@ def test_merge_order():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after_A_priority("C"), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_parallel_merge():
@@ -563,7 +564,7 @@ def test_parallel_merge():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_multiple_input_subgraphs():
@@ -676,13 +677,13 @@ def test_multiple_input_subgraphs():
     result = run_opt_pass(before()['A'], 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after_A(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
     # check case 'B'
     result = run_opt_pass(before()['B'], 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(after_B(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 def test_tuple_get_item_merge():
@@ -728,7 +729,7 @@ def test_tuple_get_item_merge():
     result = run_opt_pass(before(), 
relay.transform.MergeComposite(pattern_table))
     assert not relay.analysis.free_vars(result)
     expected = run_opt_pass(expected(), relay.transform.InferType())
-    assert relay.analysis.alpha_equal(result, expected)
+    assert tvm.ir.structural_equal(result, expected, map_free_vars=True)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relay/test_pass_partial_eval.py 
b/tests/python/relay/test_pass_partial_eval.py
index 1299084..0f3eea6 100644
--- a/tests/python/relay/test_pass_partial_eval.py
+++ b/tests/python/relay/test_pass_partial_eval.py
@@ -19,7 +19,6 @@ import numpy as np
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.analysis import alpha_equal, assert_alpha_equal
 from tvm.relay.prelude import Prelude
 from tvm.relay import op, create_executor, transform
 from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, 
RefRead, RefWrite, RefCreate
@@ -124,7 +123,7 @@ def test_ad():
     body = relay.Let(x1, o, body)
     expected = Function([d], relay.Let(x, m, body))
     expected = run_opt_pass(expected, transform.InferType())
-    assert_alpha_equal(g, expected)
+    tvm.ir.assert_structural_equal(g, expected)
 
 
 def test_if_ref():
@@ -312,7 +311,7 @@ def test_concat():
     x = Var("x", t)
     y = Var("x", t)
     orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
-    assert_alpha_equal(dcpe(orig), orig)
+    tvm.ir.assert_structural_equal(dcpe(orig), orig)
 
 
 def test_triangle_number():
@@ -321,7 +320,7 @@ def test_triangle_number():
     f_var = Var("f")
     f = Function([x], If(op.equal(x, const(0)), const(0), x + f_var(x - 
const(1))))
     orig = run_infer_type(Let(f_var, f, f_var(const(10))))
-    assert_alpha_equal(dcpe(orig), const(55))
+    tvm.ir.assert_structural_equal(dcpe(orig), const(55))
 
 
 def test_nat_update():
@@ -337,7 +336,7 @@ def test_tuple_match():
     b = relay.Var("b")
     clause = relay.Clause(relay.PatternTuple([relay.PatternVar(a), 
relay.PatternVar(b)]), a + b)
     x = relay.Match(relay.Tuple([relay.const(1), relay.const(1)]), [clause])
-    assert_alpha_equal(dcpe(x), const(2))
+    tvm.ir.assert_structural_equal(dcpe(x), const(2))
 
 
 if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_partition_graph.py 
b/tests/python/relay/test_pass_partition_graph.py
index 1f37ab8..fc8dfb6 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -339,7 +339,7 @@ def test_extern_ccompiler_default_ops():
 
     fused_mod = transform.FuseOps(2)(mod)
     expected_mod = expected()
-    assert relay.alpha_equal(fused_mod, expected_mod)
+    assert tvm.ir.structural_equal(fused_mod, expected_mod, map_free_vars=True)
 
     x_data = np.random.rand(8, 8).astype('float32')
     y_data = np.random.rand(8, 8).astype('float32')
@@ -427,7 +427,7 @@ def test_extern_dnnl():
     mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
     mod = transform.PartitionGraph()(mod)
 
-    assert relay.alpha_equal(mod, expected())
+    assert tvm.ir.structural_equal(mod, expected(), map_free_vars=True)
 
     ref_mod = tvm.IRModule()
     ref_mod["main"] = get_func()
@@ -561,7 +561,7 @@ def test_function_lifting():
 
     partitioned = partition()
     ref_mod = expected()
-    assert relay.analysis.alpha_equal(partitioned, ref_mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
 
 def test_function_lifting_inline():
@@ -631,7 +631,7 @@ def test_function_lifting_inline():
 
     partitioned = partition()
     ref_mod = expected()
-    assert relay.analysis.alpha_equal(partitioned, ref_mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
 
 def test_constant_propagation():
@@ -671,7 +671,7 @@ def test_constant_propagation():
     mod = transform.PartitionGraph()(mod)
 
     expected_mod = expected()
-    assert relay.alpha_equal(mod, expected_mod)
+    assert tvm.ir.structural_equal(mod, expected_mod, map_free_vars=True)
 
     y_data = np.random.rand(8, 8).astype('float32')
     np_add = ones + y_data
diff --git a/tests/python/relay/test_pass_qnn_legalize.py 
b/tests/python/relay/test_pass_qnn_legalize.py
index b164821..e7980e7 100644
--- a/tests/python/relay/test_pass_qnn_legalize.py
+++ b/tests/python/relay/test_pass_qnn_legalize.py
@@ -31,7 +31,7 @@ def alpha_equal(x, y):
     """
     x = x['main']
     y = y['main']
-    return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == 
analysis.structural_hash(y)
+    return tvm.ir.structural_equal(x, y) and analysis.structural_hash(x) == 
analysis.structural_hash(y)
 
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
@@ -85,12 +85,12 @@ def test_qnn_legalize():
         # Check that Relay Legalize does not change the graph.
         a = run_opt_pass(a, relay.transform.Legalize())
         b = run_opt_pass(before(), transform.InferType())
-        assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
         # Check that QNN Legalize modifies the graph.
         a = run_opt_pass(a, relay.qnn.transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
-        assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
 def test_qnn_legalize_qnn_conv2d():
diff --git a/tests/python/relay/test_pass_remove_unused_functions.py 
b/tests/python/relay/test_pass_remove_unused_functions.py
index 3381634..43b54e9 100644
--- a/tests/python/relay/test_pass_remove_unused_functions.py
+++ b/tests/python/relay/test_pass_remove_unused_functions.py
@@ -110,7 +110,7 @@ def test_call_globalvar_without_args():
     mod = get_mod()
     ref_mod = get_mod()
     mod = relay.transform.RemoveUnusedFunctions()(mod)
-    assert relay.alpha_equal(mod, ref_mod)
+    assert tvm.ir.structural_equal(mod, ref_mod, map_free_vars=True)
 
 
 if __name__ == '__main__':
diff --git a/tests/python/relay/test_pass_simplify_inference.py 
b/tests/python/relay/test_pass_simplify_inference.py
index bb39893..3a8c90b 100644
--- a/tests/python/relay/test_pass_simplify_inference.py
+++ b/tests/python/relay/test_pass_simplify_inference.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from tvm.ir import IRModule
+from tvm.ir import IRModule, structural_equal
 from tvm import relay as rly
 from tvm.relay.transform import SimplifyInference
 
@@ -56,7 +56,7 @@ def test_simplify_batchnorm(dtype='float32'):
         mod = simplify(mod)
         y1 = mod["main"].body
 
-        assert rly.analysis.graph_equal(y1, y2)
+        assert structural_equal(y1, y2, map_free_vars=True)
 
     check(2, 1, 1)
     check(4, 1, 1)
diff --git a/tests/python/relay/test_pass_to_a_normal_form.py 
b/tests/python/relay/test_pass_to_a_normal_form.py
index 29818f8..d7babf3 100644
--- a/tests/python/relay/test_pass_to_a_normal_form.py
+++ b/tests/python/relay/test_pass_to_a_normal_form.py
@@ -18,7 +18,7 @@ import numpy as np
 import tvm
 from tvm import te
 from tvm import relay
-from tvm.relay.analysis import alpha_equal, detect_feature
+from tvm.relay.analysis import detect_feature
 from tvm.relay import op, create_executor, transform
 from tvm.relay.prelude import Prelude
 from tvm.relay.testing import add_nat_definitions, count
diff --git a/tests/python/relay/test_type_functor.py 
b/tests/python/relay/test_type_functor.py
index 9e023bc..b90a688 100644
--- a/tests/python/relay/test_type_functor.py
+++ b/tests/python/relay/test_type_functor.py
@@ -18,7 +18,6 @@ import tvm
 from tvm import te
 from tvm import relay
 from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
-from tvm.relay.analysis import assert_graph_equal
 from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
                  TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
 from tvm.relay.adt import TypeData
@@ -34,7 +33,8 @@ def check_visit(typ):
     ev = TypeVisitor()
     ev.visit(typ)
 
-    assert_graph_equal(TypeMutator().visit(typ), typ)
+    tvm.ir.assert_structural_equal(TypeMutator().visit(typ), typ,
+                                   map_free_vars=True)
 
 
 def test_type_var():
diff --git a/tests/python/unittest/test_ir_type.py 
b/tests/python/unittest/test_ir_type.py
index f919f92..a0e7d2b 100644
--- a/tests/python/unittest/test_ir_type.py
+++ b/tests/python/unittest/test_ir_type.py
@@ -18,10 +18,9 @@
 import tvm
 
 def check_json_roundtrip(node):
-    from tvm.relay.analysis import graph_equal
     json_str = tvm.ir.save_json(node)
     back = tvm.ir.load_json(json_str)
-    assert graph_equal(back, node)
+    assert tvm.ir.structural_equal(back, node, map_free_vars=True)
 
 
 def test_prim_type():

Reply via email to