This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 85b8a419ab [Unity][Pass] Operator Fusion Passes (#14001)
85b8a419ab is described below

commit 85b8a419ab7a7979367bc9c85642a7c9a2df6e54
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Feb 15 21:35:14 2023 +0800

    [Unity][Pass] Operator Fusion Passes (#14001)
    
    [Unity][Pass] Operator fusion passes
    
    This PR introduces three passes for operator fusion:
    1. AnnotateTIROpPattern: analyzes the operator kind from PrimFunc.
    2. FuseOps: fuse operators for Relax functions, which adds a new fused
    relax primitive function.
    3. FuseTIR: fuse corresponding TIR PrimFuncs for the fused relax.
---
 include/tvm/relax/analysis.h                       |  11 +
 include/tvm/tir/buffer.h                           |  14 +-
 python/tvm/relax/transform/transform.py            |  43 +
 src/relax/transform/annotate_tir_op_pattern.cc     |  55 ++
 src/relax/transform/fuse_ops.cc                    | 909 +++++++++++++++++++++
 src/relax/transform/fuse_tir.cc                    | 728 +++++++++++++++++
 .../test_transform_annotate_tir_op_pattern.py      | 360 ++++++++
 tests/python/relax/test_transform_fuse_ops.py      | 759 +++++++++++++++++
 tests/python/relax/test_transform_fuse_tir.py      | 563 +++++++++++++
 tests/python/relax/test_tvmscript_parser.py        |   1 -
 10 files changed, 3441 insertions(+), 2 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 24cfe5b9bf..a55fe6797d 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const 
StructInfo& derived,
 TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
                                  arith::Analyzer* ana = nullptr);
 
+/*!
+ * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax 
FuseOps.
+ *
+ * \param func The PrimFunc to be analyzed.
+ * \return The Op Pattern Kind.
+ *
+ * \note This analysis applies to TIR functions but is primarily used by relax 
passes.
+ *       As a result we place it under the relax namespace.
+ */
+TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
+
 /*!
  * \brief Check if the given PrimFunc is essentially doing a reshape operation.
  * The reshape operation also includes expand_dims, squeeze, flatten, etc.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index d7a2aec0b9..e3a853e4c7 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -34,6 +34,18 @@
 namespace tvm {
 namespace tir {
 
+#ifndef TVM_INDEX_DEFAULT_I64
+#define TVM_INDEX_DEFAULT_I64 1
+#endif
+/*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return 
int32 */
+inline DataType DefaultIndexType() {
+#if TVM_INDEX_DEFAULT_I64
+  return DataType::Int(64);
+#else
+  return DataType::Int(32);
+#endif
+}
+
 // forward declare Stmt
 class Stmt;
 
@@ -135,7 +147,7 @@ class BufferNode : public Object {
 
   /*! \return preferred index type for this buffer node */
   DataType DefaultIndexType() const {
-    return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
+    return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType();
   }
 
   /*! \brief Determine the offset in the buffer of the given index.
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index cab18797c6..0f973db290 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
     return _ffi_api.AttachGlobalSymbol()  # type: ignore
 
 
+def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
+    """Annotate Op Pattern Kind for TIR functions
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.AnnotateTIROpPattern()  # type: ignore
+
+
+def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
+    """This pass groups bindings in a dataflow block of Relax functions and 
generates a new grouped
+    Relax function for each group, according to the fusion algorithm described 
in the pass
+    implementation. By grouping bindings into new Relax functions, we 
substitute the bindings in
+    the function being manipulated into function calls to the new grouped 
function.
+
+    A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each 
grouped function.
+
+    Parameters
+    ----------
+    fuse_opt_level : int
+        The level of fuse optimization. -1 indicates that the level will be
+        inferred from pass context.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for operator fusion.
+    """
+    return _ffi_api.FuseOps(fuse_opt_level)  # type: ignore
+
+
+def FuseTIR() -> tvm.ir.transform.Pass:
+    """Fuse primitive relax function into a larger TIR function if possible
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for tir fusion.
+    """
+    return _ffi_api.FuseTIR()  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/transform/annotate_tir_op_pattern.cc 
b/src/relax/transform/annotate_tir_op_pattern.cc
new file mode 100644
index 0000000000..b1c1ed29af
--- /dev/null
+++ b/src/relax/transform/annotate_tir_op_pattern.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/annotate_tir_op_pattern.cc
+ * \brief Annotate Op Pattern for TIR functions. It is a pass that works on TIR 
PrimFuncs,
+ *        but it is needed for relax fusion. So we put it in the relax 
namespace.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace relax {
+
+tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
+  if (f->HasNonzeroAttr("op_pattern")) {
+    return f;
+  } else {
+    relay::OpPatternKind kind = AnalyzeOpPatternKind(f);
+    return WithAttr(std::move(f), "op_pattern", 
Integer(static_cast<int>(kind)));
+  }
+}
+
+namespace transform {
+
+Pass AnnotateTIROpPattern() {
+  auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
+    return AnnotateOpPattern(std::move(f));
+  };
+  return tir::transform::CreatePrimFuncPass(pass_func, 0, 
"AnnotateTIROpPattern", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
new file mode 100644
index 0000000000..f3559b72da
--- /dev/null
+++ b/src/relax/transform/fuse_ops.cc
@@ -0,0 +1,909 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/fuse_ops.cc
+ * \brief This file contains a pass which groups bindings in a dataflow block 
of Relax
+ * functions and generates a new grouped Relax function for each group, 
according to the fusion
+ * algorithm described below. By grouping bindings into new Relax functions, 
we substitute the
+ * bindings in the function being manipulated into function calls to the new 
grouped function.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each 
grouped function.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+
+#include <optional>
+
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+
+namespace tvm {
+namespace relax {
+
+/*
+  Note on Fusing algorithm:
+
+  The main challenge of general fusor is to handle possible diamond shape 
branches,
+  in the following graph, conv2d can be fused to elemwise add.
+
+            conv2d
+            /  |  \
+           /   |   \
+         op    op   op
+          \    |    /
+           \   |   /
+          elemwise add
+               |
+
+  However, at the point of conv2d we do not necessarily know that all the 
future paths
+  will merge at the elemwise add. The fusion algorithm applies post-dominator 
analysis.
+
+  The immediate post-dominator of a node is defined as the closest node where all 
the future paths go
+  into. In the above case, the elemwise add is the post-dominator of conv2d. 
The general algorithm
+  is as follows:
+
+  - Construct a DAG of dataflow graph for dominator analysis
+  - Construct a post-dominator tree which gives immediate post dominator of 
each node.
+  - Run fusion algorithm with the given post-dominator information.
+
+  Note that, because we run analysis on a DAG, we use a single pass 
post-dominator
+  tree construction algorithm via LCA, which is simpler than the full version 
that handles cycles.
+
+  The fusion algorithm traverses from each node and checks if it can be fused 
to its
+  immediate post dominator. It has to check the following things:
+
+  - CheckPath: check all the path between a node and its immediate 
post-dominator
+               satisfies the fuse condition.
+  - Note that these intermediate nodes can already be fused with other nodes, 
the algorithm
+      will still run correctly.
+  - CommitFuse: mark all the nodes between source and post-dominator as the 
same group.
+  - We use a Union-Find data structure to manage the groups.
+*/
+
+using relay::GraphPartitioner;
+using relay::IndexedForwardGraph;
+using relay::OpPatternKind;
+using support::LinkNode;
+
+constexpr uint32_t kMaxFusedOps = 256;
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer);
+
+class GraphCreator : public ExprVisitor {
+ public:
+  /*!
+   * \brief Create a IndexedForwardGraph according to the input module. The 
graph will be used for
+   * graph partition and operator fusion.
+   * \param mod The module which the creation accords to
+   * \param arena The allocator of all the internal node objects
+   * \return The created IndexedForwardGraph
+   */
+  static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
+    // Since cross-function call is not supported yet, FuseOps only serves the 
entry function, whose
+    // name is "main".
+    auto relax_func = Downcast<Function>(mod->Lookup("main"));
+    GraphCreator creator(mod, arena);
+    creator(relax_func);
+
+    // The algorithm of the graph creator ensures that each created node will 
be added to the
+    // post-dfs order and will be set its op pattern. Thus we check whether 
all these containers
+    // have the same size.
+    size_t n_nodes = creator.graph_.node_map.size();
+    ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size());
+    ICHECK_EQ(n_nodes, creator.initialized_nodes_.size());
+
+    return creator.graph_;
+  }
+
+ private:
+  explicit GraphCreator(IRModule mod, support::Arena* arena)
+      : mod_(std::move(mod)), arena_(arena) {}
+
+  void VisitExpr_(const FunctionNode* func) final {
+    for (const Var& param : func->params) {
+      IndexedForwardGraph::Node* param_node = CreateNode(param.get());
+      // The parameter is passed in from the outside, and thus it's marked as 
an external reference,
+      // and its pattern is `kOpaque`.
+      MarkAsExternRef(param_node);
+      SetNodePattern(param_node, OpPatternKind::kOpaque);
+      AddToPostDFSOrder(param_node, param.get());
+    }
+    ExprVisitor::VisitExpr_(func);
+  }
+
+  void VisitBindingBlock(const BindingBlock& block) final {
+    if (const auto* df_block = block.as<DataflowBlockNode>()) {
+      VisitBindingBlock_(df_block);
+    }
+    // We skip ordinary binding blocks since they might be impure (with side 
effect or control flow)
+  }
+
+  // TODO(tvm-team): how to deal with MatchCast binding here
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
+
+    // If the variable is not a dataflow variable, it must be the output 
variable of this dataflow
+    // block
+    if (!binding->var->IsInstance<DataflowVarNode>()) {
+      this->MarkAsExternRef(node);
+    }
+    if (const auto* call = binding->value.as<CallNode>()) {
+      // Case 1. The expression is a CallNode
+      VisitCall(call, node);
+    } else if (const auto* tuple_get_item = 
binding->value.as<TupleGetItemNode>()) {
+      // Case 2. The expression is a TupleGetItemNode
+      VisitTupleGetItem(tuple_get_item, node);
+    } else {
+      VisitUnsupportedNode(binding->value, node);
+      // Case 3. The type of the expression is not fusion-supported.
+      // In this case, we skip adding edges, adding an empty node into graph.
+    }
+    AddToPostDFSOrder(node, binding->var.get());
+  }
+
+  /********** Non-Leaf Expression Nodes **********/
+
+  void VisitCall(const CallNode* call, IndexedForwardGraph::Node* 
binding_var_node) {
+    ICHECK_NOTNULL(binding_var_node);
+
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    OpPatternKind pattern = OpPatternKind::kOpaque;
+    Array<Expr> args = call->args;
+
+    // - If the op being called is a TIR PrimFunc, we get the function op 
pattern directly from the
+    // function attribute and visit the arguments one by one.
+    // - Otherwise, the pattern of the current binding variable node is set to 
`kOpaque`, and we
+    // recurse into the call expression.
+    const auto* op = call->op.as<OpNode>();
+    if (op == call_tir_op_.get()) {
+      const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
+
+      // Override args for call_tir
+      args = Downcast<Tuple>(call->args[1])->fields;
+
+      // TODO(tvm-team): handle the shape argument (args[3])
+      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+      if (opt_pattern.defined()) {
+        pattern = 
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+      } else {
+        pattern = OpPatternKind::kOpaque;
+      }
+    }
+    // The pattern of the current binding variable node is set to the pattern 
of this operator.
+    SetNodePattern(binding_var_node, pattern);
+    // Visit all call args
+    for (const Expr& arg : args) {
+      ICHECK(IsLeaf(arg));
+      VisitLeaf(arg, binding_var_node, pattern);
+    }
+  }
+
+  void VisitTupleGetItem(const TupleGetItemNode* tuple_item,
+                         IndexedForwardGraph::Node* binding_var_node) {
+    ICHECK_NOTNULL(binding_var_node);
+
+    SetNodePattern(binding_var_node, OpPatternKind::kInjective);
+    VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+  }
+
+  void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node* 
binding_var_node) {
+    ICHECK_NOTNULL(binding_var_node);
+    SetNodePattern(binding_var_node, OpPatternKind::kOpaque);
+
+    auto visit_leaves = [this, &binding_var_node](const Expr& e) {
+      if (e->IsInstance<VarNode>() || e->IsInstance<ConstantNode>()) {
+        VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque);
+      }
+    };
+    PostOrderVisit(expr, visit_leaves);
+  }
+
+  /********** Leaf Expression Nodes **********/
+
+  void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* 
binding_var_node,
+                 const OpPatternKind& pattern) {
+    ICHECK_NOTNULL(binding_var_node);
+
+    // Recursive visit if it's Tuple
+    if (const auto* tuple = leaf_expr.as<TupleNode>()) {
+      for (const Expr& expr : tuple->fields) {
+        VisitLeaf(expr, binding_var_node, pattern);
+      }
+      return;
+    }
+
+    auto it = graph_.node_map.find(leaf_expr.get());
+    IndexedForwardGraph::Node* leaf_node = nullptr;
+    if (it != graph_.node_map.end()) {
+      leaf_node = it->second;
+    } else if (leaf_expr->IsInstance<ConstantNode>()) {
+      leaf_node = CreateNode(leaf_expr.get());
+      // Since we never fuse constants, the pattern of the constant is set to 
`kOpaque`.
+      SetNodePattern(leaf_node, OpPatternKind::kOpaque);
+      AddToPostDFSOrder(leaf_node, leaf_expr.get());
+    } else {
+      LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: 
" << leaf_expr
+                 << " used before definition.";
+    }
+    AddEdge(leaf_node, binding_var_node, pattern);
+  }
+
+  /********** Helper Functions **********/
+
+  /*!
+   * \brief Check whether the expression is a leaf expression
+   * \param expr The expression to be checked
+   * \return Whether the expression is a leaf expression
+   * \note In order to avoid too much refactor, this method is a simple 
copy-paste of the is-leaf
+   * check in "block_builder.cc". And it should be refactored in the future.
+   * \sa src/relax/ir/block_builder.cc
+   */
+  static bool IsLeaf(const Expr& expr) {
+    // NOTE: Tuples are treated as leaf nodes for ergonomics
+    return expr.as<VarNode>() || expr.as<GlobalVarNode>() || 
expr.as<ConstantNode>() ||
+           expr.as<ShapeExprNode>() || expr.as<ExternFuncNode>() || 
expr.as<OpNode>() ||
+           expr.as<TupleNode>();
+  }
+
+  /*!
+   * \brief Create a graph node corresponding to the input key
+   * \param key The object which is used to create the graph node
+   * \return The created graph node
+   * \note The node corresponding to each key is supposed to be created for 
only once
+   */
+  IndexedForwardGraph::Node* CreateNode(const Object* key) {
+    ICHECK(graph_.node_map.find(key) == graph_.node_map.end())
+        << "The node corresponding to the input key is not supposed to be 
created before";
+    auto* node = arena_->make<IndexedForwardGraph::Node>();
+    graph_.node_map[key] = node;
+    return node;
+  }
+
+  /*!
+   * \brief Append the input node to the post-dfs order of the graph
+   * \param node The node to be appended
+   * \param key The key corresponding to the node
+   * \note Each node is supposed to be appended to the post-dfs order for only 
once
+   */
+  void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
+    auto it = graph_.node_map.find(key);
+    ICHECK(it != graph_.node_map.end() && it->second == node)
+        << "The node must have been created before adding to the post-dfs 
order";
+
+    // We only set the reference of the node when adding it to the post-dfs 
order. Thus, if the
+    // reference of a node is already set, it must have been appended to the 
post-dfs order.
+    ICHECK(node->ref == nullptr)
+        << "The node is not supposed to be added into the post-dfs order 
before";
+
+    node->ref = key;
+    node->index = graph_.post_dfs_order.size();
+    graph_.post_dfs_order.push_back(node);
+  }
+
+  /*!
+   * \brief Add an edge from the input start to the input end in the graph, 
with specific pattern
+   * \param start The start of the edge
+   * \param end The end of the edge
+   * \param pattern The pattern of this edge
+   */
+  void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* 
end,
+               OpPatternKind pattern) {
+    auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
+    link->value.node = end;
+    link->value.pattern = pattern;
+    start->outputs.Push(link);
+  }
+
+  /*!
+   * \brief Mark a given node as "external reference", which means the node 
cannot be fused as an
+   * intermediate node
+   * \param node The graph node to be marked
+   */
+  void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = 
true; }
+
+  /*!
+   * \brief Set the pattern of the input node
+   * \param node The graph node to be set
+   * \param pattern The pattern of the node
+   */
+  void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) {
+    ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end())
+        << "The input node is supposed to be set pattern for only once";
+    initialized_nodes_.insert(node);
+    node->pattern = pattern;
+  }
+
+ private:
+  /*! \brief The IRModule from which the indexed forward graph is created */
+  IRModule mod_;
+  /*! \brief The allocator of all the internal node objects */
+  support::Arena* arena_;
+  /*! \brief The created indexed forward graph */
+  IndexedForwardGraph graph_;
+  /*! \brief The graph nodes whose patterns are set */
+  std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+};
+
+/*!
+ * \brief The ExprMutator used to create a new grouped function
+ * \details The workflow of this ExprMutator is:
+ *  - The bindings in the function will be added by OperatorFusor via 
`AppendBinding(...)`.
+ *  - When adding a new binding through `AppendBinding(...)`, we check whether 
the variables and
+ *  constants used by the binding are defined by some previous added binding. 
And for the undefined
+ *  variables and constants, we add them to the argument list and created new 
variables as the
+ *  corresponding parameters.
+ *  - When `CreateFunction()` is called, we go through each binding and update 
the binding with the
+ *  new parameters. After that we wrap all bindings with a DataflowBlock and a 
Function.
+ */
+class FunctionCreator : public ExprMutator {
+ public:
+  explicit FunctionCreator(bool lift_constant) : lift_constant_(lift_constant) 
{}
+  /*!
+   * \brief Append a new binding to this function and possibly create new 
parameters for the
+   * function accordingly
+   * \param binding The binding to be appended
+   * \note Allowed bindings are:
+   *  - VarBinding with value being a call node calling `relax.call_tir`.
+   *  - VarBinding with value being a tuple-get-item node.
+   * // TODO(tvm-team): handle match shape
+   */
+  void AppendBinding(const Binding& binding) {
+    ICHECK(!function_.defined())
+        << "The `function_` is supposed to be uncreated when adding bindings";
+
+    if (const auto* var_binding = binding.as<VarBindingNode>()) {
+      if (const auto* call = var_binding->value.as<CallNode>()) {
+        if (call->op == Op::Get("relax.call_tir")) {
+          // Update the name of the function.
+          name_hint_ = name_hint_ + "_" + 
Downcast<GlobalVar>(call->args[0])->name_hint;
+
+          const Tuple& args = Downcast<Tuple>(call->args[1]);
+          for (const Expr& arg : args->fields) {
+            CheckDefAndUpdateParam(arg);
+          }
+          // TODO(tvm-team): handle shape expr
+        } else {
+          if (call->op->IsInstance<OpNode>()) {
+            name_hint_ = name_hint_ + "_" + Downcast<Op>(call->op)->name;
+          } else if (call->op->IsInstance<GlobalVarNode>()) {
+            std::string gvar_name = Downcast<GlobalVar>(call->op)->name_hint;
+            if (auto pos = gvar_name.find("fused_"); pos == 0) {
+              name_hint_ = name_hint_ + "_" + 
gvar_name.substr(std::string("fused_").size());
+            } else {
+              name_hint_ = name_hint_ + "_" + gvar_name;
+            }
+          }
+
+          for (const Expr& arg : call->args) {
+            CheckDefAndUpdateParam(arg);
+          }
+        }
+      } else {
+        const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
+        ICHECK(tuple_item != nullptr);
+        CheckDefAndUpdateParam(tuple_item->tuple);
+      }
+
+      // Mark the binding variable as defined.
+      defined_vars_.insert(var_binding->var.get());
+      // Set var as output true if the binding is not a dataflow variable
+      if (!var_binding->var->IsInstance<DataflowVarNode>()) {
+        AppendOutput(var_binding->var);
+      }
+    } else {
+      // TODO(tvm-team): handle match_cast
+    }
+    bindings_.push_back(binding);
+  }
+
+  /*! \brief Set a var defined in the group as output. */
+  size_t AppendOutput(const Var& var) {
+    ICHECK(defined_vars_.count(var.get()));
+    auto output_idx = GetOutputIndex(var);
+    if (output_idx) {
+      return *output_idx;
+    }
+    output_vars_.push_back(var.get());
+    return output_vars_.size() - 1;
+  }
+
+  /*!
+   * \brief Create the grouped function according to the collected 
bindings and parameters
+   * \param composite_name The name to identify the pattern this function is 
created from, if any.
+   * It will become the value of the kComposite attribute of the created 
function.
+   * \note The created function won't be returned immediately. It's stored in 
the `function_` field.
+   */
+  void CreateFunction(Map<String, ObjectRef> group_attrs) {
+    // Step 1. Start constructing a new dataflow block.
+    builder_->BeginDataflowBlock();
+
+    // Step 2. Visit each binding and collect outputs one by one.
+    Array<Expr> outputs(output_vars_.size(), Expr());
+    for (const Binding& binding : bindings_) {
+      if (auto output_idx = GetOutputIndex(binding->var)) {
+        // Case 1. It is an output binding
+        // We only allow VarBinding as output.
+        const auto* var_binding = binding.as<VarBindingNode>();
+        ICHECK_NOTNULL(var_binding);
+        Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value));
+        var_remap_[var_binding->var->vid] = output_var;
+        outputs.Set(*output_idx, output_var);
+      } else {
+        // Case 2. It is an internal binding, add it to the binding list.
+        VisitBinding(binding);
+      }
+    }
+
+    // Step 3. Finish constructing the new block.
+    BindingBlock new_block = builder_->EndBlock();
+    ICHECK(!outputs.empty()) << "At least one output is required.";
+    Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+    body = builder_->Normalize(body);
+    body = builder_->Normalize(SeqExpr({new_block}, body));
+    group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+    function_ = Function(/*params=*/params_,           //
+                         /*body=*/body,                //
+                         /*ret_struct_info=*/NullOpt,  //
+                         /*attrs=*/DictAttrs(group_attrs));
+  }
+
+  /*! \brief The original bindings of the function */
+  Array<Binding> bindings_;
+  /*! \brief The parameters of the function */
+  Array<Var> params_;
+  /*! \brief The arguments to call the function on the caller side */
+  Array<Expr> arguments_;
+  /*! \brief The name for the fused function */
+  String name_hint_ = "fused";
+  /*! \brief The constructed Relax function */
+  Function function_{nullptr};
+
+ private:
+  std::optional<size_t> GetOutputIndex(Var v) {
+    auto it = std::find(output_vars_.begin(), output_vars_.end(), v.get());
+    if (it != output_vars_.end()) {
+      return std::distance(output_vars_.begin(), it);
+    }
+    return std::nullopt;
+  }
+
+  /*!
+   * \brief Check whether the input expression is defined within this 
function. If not, create a new
+   * parameter for the expression.
+   * \param expr The expression to be checked
+   */
+  void CheckDefAndUpdateParam(const Expr& expr) {
+    // If the expression has already served as an argument, no need to create 
another one for it.
+    if (std::find(arguments_.begin(), arguments_.end(), expr) != 
arguments_.end()) {
+      return;
+    }
+
+    // If the expression is not a variable or is an undefined variable, it 
should be populated as a
+    // parameter of the relax function.
+    const auto* var = expr.as<VarNode>();
+    if ((var == nullptr || defined_vars_.count(var) == 0) &&
+        (lift_constant_ || !expr->IsInstance<ConstantNode>())) {
+      String name{nullptr};
+      if (var != nullptr) {
+        name = var->name_hint();
+      } else {
+        name = String("param_" + std::to_string(n_param_for_const_++));
+      }
+
+      Var param(std::move(name), GetStructInfo(expr));
+      arguments_.push_back(expr);
+      params_.push_back(param);
+    }
+  }
+
+  Expr VisitExpr(const Expr& expr) final {
+    // If the expression serves as an argument, return its corresponding 
parameter.
+    auto it = std::find(arguments_.begin(), arguments_.end(), expr);
+    if (it != arguments_.end()) {
+      return params_[it - arguments_.begin()];
+    }
+    // Otherwise, recurse into this expression.
+    return ExprMutator::VisitExpr(expr);
+  }
+
+ private:
+  /*! \brief The variables defined in this function */
+  std::unordered_set<const VarNode*> defined_vars_;
+  /*! \brief The number of parameters reserved for constants */
+  int n_param_for_const_ = 0;
+  /*! \brief The output vars */
+  std::vector<const VarNode*> output_vars_;
+  /*! \brief Whether or not to lift bound constants to parameters */
+  bool lift_constant_;
+};
+
+/*!
+ * \brief The ExprMutator used to fuse the operators in Relax functions
+ * \details Given the partition results on the indexed-forward graph, for each 
group whose size is
+ * larger than one, we create a new grouped function for it, containing all 
bindings in that group.
+ * And we substitute the bindings in a group with a single function call to 
the newly created
+ * grouped function. The workflow of this ExprMutator is: for each dataflow 
block,
+ *   - we go through the bindings one by one. For each binding, if it is in a 
group whose size is
+ *   larger than one, we add the binding to the function of the group it is in 
and update the
+ *   parameters and arguments of that function;
+ *   - then we finalize all the grouped functions by updating their bindings 
using BlockBuilder;
+ *   - lastly, we go through the bindings again and substitute the bindings in 
a group with a single
+ *   call to the corresponding grouped function.
+ *
+ * After transforming a Relax function, we update the function in the IRModule. Besides, we add
+ * all newly created grouped functions to the IRModule.
+ */
class OperatorFusor : public ExprMutator {
 public:
  using Group = GraphPartitioner::Group;
  using GroupMap = std::unordered_map<const Object*, Group*>;

  /*!
   * \brief Construct the fusor from a precomputed object-to-group assignment.
   * \param mod The IRModule to be transformed
   * \param obj2group Mapping from each leaf AST object to its (root) group
   * \param lift_constants Whether bound constants are lifted to parameters of the grouped function
   * \note The base-class ctor runs before the member init list, so `ExprMutator(mod)` sees the
   * module before `mod_` moves from it.
   */
  OperatorFusor(IRModule mod, const GroupMap& obj2group, bool lift_constants = true)
      : ExprMutator(mod),
        mod_(std::move(mod)),
        obj2group_(obj2group),
        lift_constants_(lift_constants) {}

  /*!
   * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition
   * result on that graph, the constructor creates a mapping from each leaf AST object
   * (e.g. parameters, variables, constants) to the group of the node corresponding to the object
   * in the graph.
   * \param mod The IRModule to be transformed
   * \param graph The indexed-forward graph of the input IRModule
   * \param groups The grouped result of the group partition on the input indexed-forward graph.
   * \param lift_constant Whether bound constants are lifted to parameters of the grouped function
   */
  OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const std::vector<Group*>& groups,
                bool lift_constant = true)
      : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {}

  /*!
   * \brief The main transformation on the IRModule
   * \return The new IRModule after transformation
   */
  IRModule Transform() {
    // Updates go through builder_ (which owns its own module context), so iterating
    // mod_->functions here is not invalidated by UpdateFunction below.
    for (const auto& [gv, func] : mod_->functions) {
      // Only visit Relax function without attr kPrimitive.
      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
        auto updated_func = Downcast<Function>(VisitExpr(func));
        builder_->UpdateFunction(gv, updated_func);
      }
    }
    return builder_->GetContextIRModule();
  }

 private:
  /*!
   * \brief Build the object-to-group map from the partition result, resolving each group
   * to its union-find root so that all members of a group share one representative.
   */
  static GroupMap CreateGroupMap(const IndexedForwardGraph& graph,
                                 const std::vector<Group*>& groups) {
    GroupMap obj2group;
    for (int nid = 0; nid < static_cast<int>(graph.post_dfs_order.size()); ++nid) {
      Group* group_root = groups[nid]->FindRoot();
      ICHECK(group_root != nullptr);
      ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
      obj2group[graph.post_dfs_order[nid]->ref] = group_root;
    }
    return obj2group;
  }

  /*! \brief Check whether the given function's return struct-info is a tuple. */
  bool IsTupleOutput(Function f) {
    auto sinfo = GetStructInfo(f).as<FuncStructInfoNode>();
    ICHECK(sinfo);
    return sinfo->ret->IsInstance<TupleStructInfoNode>();
  }

  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
    if (const auto* df_block = block.as<DataflowBlockNode>()) {
      return VisitBindingBlock_(df_block);
    }
    // We skip ordinary binding blocks since they might be impure (with side effect or control flow)
    return block;
  }

  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
    // group2func_ is per-block state; reset it for each dataflow block.
    group2func_.clear();

    // Step 1. Collect the bindings for each grouped function.
    CollectFuncBindings(block->bindings);

    // Step 2. Collect all group's boundary (i.e. the output vars for each group)
    CollectFuncBoundary(block->bindings);

    // Step 3. Create the grouped function for each group.
    for (auto& [g, creator] : group2func_) {
      creator.CreateFunction(g->attrs);
    }

    // Step 4. Start generating the new binding block.
    //  - For groups with single binding, we directly recurse into the binding and emit the new one.
    //  - For groups with multiple bindings, we emit the call to the grouped function only when
    //  visiting the last binding of the group, because only by doing this we don't break the
    //  dependencies among the bindings of different groups. And therefore, we will skip all but the
    //  last binding of the group.
    builder_->BeginDataflowBlock();

    // For each group, record which variables need to be remapped to the output of TupleGetItem.
    // Only relevant when the output of the grouped function is a tuple.
    std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;

    // A grouped function which returns a tuple requires attaching TupleGetItem to each element and
    // remapping variables in earlier bindings appropriately. Thus, a binding whose value depends on
    // some elements of a tuple from other group's function must be emitted after a call to the
    // tuple-producing function is emitted and remapping is done.
    // To guarantee this, we process bindings in the order of the topological sort of the group
    // dependency relations.
    for (const auto& binding : TopoSortByGroupDep(block->bindings)) {
      // Case 1. If the binding is the only binding in its group, recurse into it and emit the
      // transformed binding as usual.
      Group* group = GetGroupFromBinding(binding);
      if (group->num_nodes == 1 && group->attrs.empty()) {
        VisitBinding(binding);
        continue;
      }

      const auto& it_creator = group2func_.find(group);
      ICHECK(it_creator != group2func_.end());
      const FunctionCreator& func_info = it_creator->second;

      // If this binding belongs to a group whose output is a tuple, the original bound variable
      // needs to be remapped to the output of TupleGetItem after the corresponding tuple is
      // emitted.
      if (IsTupleOutput(func_info.function_) && tuple_get_indices_.count(binding->var.get())) {
        pending_tuple_get[group].push_back(binding->var);
      }

      // Case 2. If the binding is not the last binding of the group, we skip it.
      if (!func_info.bindings_.back().same_as(binding)) {
        continue;
      }

      // Case 3. The binding is the last binding of the group.
      const auto* var_binding = binding.as<VarBindingNode>();
      ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 "
                                        "is supposed to be a variable binding";

      // Step a. Add the grouped function to the IRModule
      GlobalVar gv = builder_->AddFunction(func_info.function_, func_info.name_hint_);

      // Step b. Create the call to the deduplicated function, and then emit the call.
      //  - If this binding is an output binding, emit an output variable.
      //  - Otherwise, emit a dataflow variable.
      Var new_var;
      Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_));

      if (var_binding->var->IsInstance<DataflowVarNode>()) {
        new_var = builder_->Emit(call_to_emit);
      } else {
        new_var = builder_->EmitOutput(call_to_emit);
      }

      // Step c. Update the mapping used for the remapping of the binding variables.
      if (IsTupleOutput(func_info.function_)) {
        // If the output is a tuple, attach TupleGetItem to all tuple elements, and
        // remap variables appropriately.
        // The variables that need to be remapped and the corresponding tuple indices are
        // available in pending_tuple_get and tuple_get_indices_ respectively.
        for (const auto& var : pending_tuple_get[group]) {
          auto tuple_get = TupleGetItem(new_var, tuple_get_indices_[var.get()]);
          var_remap_[var->vid] = builder_->Emit(tuple_get);
        }
      } else {
        var_remap_[var_binding->var->vid] = new_var;
      }
    }
    // Step 5. Finish the binding block generation.
    return builder_->EndBlock();
  }

  /*!
   * \brief Collect the bindings for each grouped function and update the information of the grouped
   * function
   * \param bindings The bindings to be collected
   * \note The function update is done by `AppendBinding(...)`
   */
  void CollectFuncBindings(const Array<Binding>& bindings) {
    for (const Binding& binding : bindings) {
      // If the binding is the only binding in its group, there is no need to create a new function.
      Group* group = GetGroupFromBinding(binding);
      if (group->num_nodes == 1 && group->attrs.empty()) {
        continue;
      }
      // Add the binding to the grouped function it's in, and update the function information
      // accordingly. FunctionCreator is constructed lazily on first use.
      if (!group2func_.count(group)) {
        group2func_.emplace(group, lift_constants_);
      }
      group2func_.find(group)->second.AppendBinding(binding);
    }
  }

  /*!
   * \brief Collect every group's boundary: variables consumed by a different group become
   * outputs of their producing group, and cross-group edges are recorded in group_deps_.
   * \param bindings The bindings of the current dataflow block
   */
  void CollectFuncBoundary(const Array<Binding>& bindings) {
    for (const Binding& binding : bindings) {
      // Step 1. Get current binding's group
      Group* cur_group = GetGroupFromBinding(binding);

      // Step 2. Collect all used vars in the binding value and update boundary.
      // - If the var's group is same as the binding's, the var is defined in the same group
      // - If the var's group is different with the binding's, the var must be the output from
      //   another group. Mark it to be the group output.
      auto update_boundary = [this, binding, &cur_group](const Expr& e) {
        if (e->IsInstance<VarNode>()) {
          const Var& used_var = Downcast<Var>(e);
          Group* producer_group = GetGroupFromVar(used_var);
          // Only check those group defined before.
          // Skip the vars from input or groups with single binding.
          if (producer_group != cur_group) {
            // A back-edge (producer already depends on consumer) means fusion produced a cycle.
            ICHECK(!group_deps_[producer_group].count(cur_group))
                << "A cyclic dependency detected between the groups " << binding->var->name_hint()
                << " and " << used_var->name_hint() << " are in.";
            group_deps_[cur_group].insert(producer_group);
          }

          // Only groups that own a FunctionCreator (i.e. multi-binding groups) need an
          // explicit output slot for the crossed variable.
          if (auto producer = group2func_.find(producer_group);
              producer_group != cur_group && producer != group2func_.end()) {
            auto output_index = producer->second.AppendOutput(used_var);
            tuple_get_indices_[used_var.get()] = output_index;
          }
        }
      };

      if (const auto* var_binding = binding.as<VarBindingNode>()) {
        PostOrderVisit(var_binding->value, update_boundary);
      } else {
        const auto* match_cast = binding.as<MatchCastNode>();
        ICHECK_NOTNULL(match_cast);
        PostOrderVisit(match_cast->value, update_boundary);
      }
    }
  }

  /*!
   * \brief Get the group which the input binding is in
   * \param binding The binding to be queried
   * \return The pointer to the group which the input binding is in
   */
  Group* GetGroupFromBinding(const Binding& binding) {
    Var var = binding->var;
    return GetGroupFromVar(var);
  }

  /*!
   * \brief Get the group which the input var is in
   * \param var The var to be queried
   * \return The pointer to the group which the input var is in
   */
  Group* GetGroupFromVar(const Var& var) {
    const auto& it_group = obj2group_.find(var.get());
    ICHECK(it_group != obj2group_.end());
    Group* group = it_group->second;
    return group->FindRoot();
  }

  /*!
   * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by
   * recursing into each argument
   * \param args The arguments to be updated
   * \return The updated arguments
   */
  Array<Expr> UpdateArgs(const Array<Expr>& args) {
    Array<Expr> new_args;
    new_args.reserve(args.size());
    for (const Expr& arg : args) {
      new_args.push_back(VisitExpr(arg));
    }
    return new_args;
  }

 private:
  // Topologically sort bindings according to the group dependency relations.
  Array<Binding> TopoSortByGroupDep(const Array<Binding>& bindings) {
    std::unordered_map<Group*, std::vector<Binding>> bindings_per_group;
    // The order to visit groups should respect the original order of bindings as much as possible.
    std::vector<Group*> group_order;
    for (const auto& binding : bindings) {
      auto g = GetGroupFromBinding(binding);
      group_order.push_back(g);  // Duplication does not matter since each group is visited once.
      bindings_per_group[g].push_back(binding);
    }

    std::unordered_set<Group*> visited;

    // Post-order DFS: emit a group only after all of its producer groups were emitted.
    std::function<void(Group*, std::function<void(Group*)>)> dfs_visit;
    dfs_visit = [this, &visited, &dfs_visit](Group* g, auto leaf_fun) {
      if (!visited.count(g)) {
        visited.insert(g);
        for (auto dep : group_deps_[g]) {
          dfs_visit(dep, leaf_fun);
        }
        leaf_fun(g);
      }
    };

    Array<Binding> sorted;

    for (auto g : group_order) {
      dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) {
        for (const auto& binding : bindings_per_group[leaf]) {
          sorted.push_back(binding);
        }
      });
    }

    return sorted;
  }

  /*! \brief The IRModule. */
  IRModule mod_;
  /*! \brief Internal arena. */
  support::Arena arena_;
  /*! \brief The group assignment map. */
  GroupMap obj2group_;
  /*! \brief Internal function information map. */
  std::unordered_map<Group*, FunctionCreator> group2func_;
  /*! \brief Record the index for TupleGetItem if the variable needs to be remapped to an output
   * tuple element after fusion. */
  std::unordered_map<const VarNode*, int> tuple_get_indices_;
  /*! \brief A map from a group to its dependent groups, used to detect cyclic dependencies. */
  std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
  /*! \brief Whether or not to lift bound constants to parameters of the grouped function. */
  bool lift_constants_{true};
};
+
+IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
+  support::Arena arena;
+
+  // Step 1. Create the indexed-forward graph according to the input IRModule.
+  IndexedForwardGraph graph = GraphCreator::Create(mod, &arena);
+
+  // Step 2. Partition the graph by applying the fusion algorithm.
+  std::vector<GraphPartitioner::Group*> groups =
+      GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph);
+
+  // Step 3. Transform the IRModule by fusing the operators in accordance with 
the graph partition
+  // results.
+  return OperatorFusor(mod, graph, groups, /*lift_constants*/ 
true).Transform();
+}
+
+namespace transform {
+
+Pass FuseOps(int fuse_opt_level) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) {
+        int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
+        auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth", 
Integer(kMaxFusedOps));
+        return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue());
+      };
+  return CreateModulePass(/*pass_function=*/pass_func,  //
+                          /*opt_level=*/0,              //
+                          /*pass_name=*/"FuseOps",      //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
new file mode 100644
index 0000000000..fa5c296d27
--- /dev/null
+++ b/src/relax/transform/fuse_tir.cc
@@ -0,0 +1,728 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+// TODO(Siyuan): move it to somewhere under tir folder
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in 
statements or expressions.
+ */
class FuseTIRBufferSubstitor : private StmtExprMutator {
 public:
  /*!
   * \brief Entry point: substitute each source buffer with its target buffer in `stmt`.
   * \param buffer_map Mapping from source buffers to target buffers
   * \param stmt The statement to rewrite
   * \return The rewritten statement
   */
  static Stmt Substitute(const Map<Buffer, Buffer>& buffer_map, Stmt stmt) {
    return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt));
  }

 private:
  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map) {
    // Index by the buffer's backing `data` var so that any reference reachable
    // through the var (loads, stores, regions) can be redirected.
    for (const auto& kv : buffer_map) {
      const Buffer& src = kv.first;
      const Buffer& tgt = kv.second;
      buffer_var_map_[src->data.get()] = tgt;
    }
  }

  // Redirect bare references to a substituted buffer's data var.
  PrimExpr VisitExpr_(const VarNode* _op) final {
    auto it = buffer_var_map_.find(_op);
    if (it != buffer_var_map_.end()) {
      return it->second->data;
    } else {
      return GetRef<PrimExpr>(_op);
    }
  }

  // Rewrite loads from a substituted buffer to load from the target buffer.
  PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
    auto it = buffer_var_map_.find(load->buffer->data.get());
    if (it != buffer_var_map_.end()) {
      auto n = make_object<BufferLoadNode>(*load.get());
      n->buffer = it->second;
      return BufferLoad(n);
    } else {
      return std::move(load);
    }
  }

  // Rewrite stores into a substituted buffer to store into the target buffer.
  Stmt VisitStmt_(const BufferStoreNode* _op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
    auto it = buffer_var_map_.find(store->buffer->data.get());
    if (it != buffer_var_map_.end()) {
      auto n = CopyOnWrite(store.get());
      n->buffer = it->second;
      return BufferStore(n);
    } else {
      return std::move(store);
    }
  }

  // Legacy flat-memory Load: redirect the buffer var.
  PrimExpr VisitExpr_(const LoadNode* _op) final {
    Load load = Downcast<Load>(StmtExprMutator::VisitExpr_(_op));
    auto it = buffer_var_map_.find(load->buffer_var.get());
    if (it != buffer_var_map_.end()) {
      auto n = make_object<LoadNode>(*load.get());
      n->buffer_var = it->second->data;
      return Load(n);
    } else {
      return std::move(load);
    }
  }

  // Legacy flat-memory Store: redirect the buffer var.
  Stmt VisitStmt_(const StoreNode* _op) final {
    Store store = Downcast<Store>(StmtExprMutator::VisitStmt_(_op));
    auto it = buffer_var_map_.find(store->buffer_var.get());
    if (it != buffer_var_map_.end()) {
      auto n = CopyOnWrite(store.get());
      n->buffer_var = it->second->data;
      return Store(n);
    } else {
      return std::move(store);
    }
  }

  // Rewrite a block's match_buffers and read/write regions, de-duplicating
  // identical access regions afterwards.
  Stmt VisitStmt_(const BlockNode* _op) final {
    Block block = Downcast<Block>(StmtMutator::VisitStmt_(_op));

    // Define the mutation functions.
    auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
      const Buffer& src_buffer = match_buffer->source->buffer;
      auto it = buffer_var_map_.find(src_buffer->data.get());
      if (it != buffer_var_map_.end()) {
        return MatchBufferRegion(match_buffer->buffer,
                                 BufferRegion(it->second, match_buffer->source->region));
      } else {
        return match_buffer;
      }
    };

    auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
      return it == buffer_var_map_.end() ? buffer_region
                                         : BufferRegion(it->second, buffer_region->region);
    };

    // Step 1. Mutate `match_buffers`.
    Array<MatchBufferRegion> match_buffers =
        MutateArray(block->match_buffers, f_mutate_match_buffers);
    // Step 2. Mutate the read/write region.
    Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
    Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);

    // Substitution may map two distinct buffers onto the same target; merge
    // the resulting duplicate regions.
    reads = UnionAccessRegion(reads);
    writes = UnionAccessRegion(writes);

    if (reads.same_as(block->reads) &&    //
        writes.same_as(block->writes) &&  //
        match_buffers.same_as(block->match_buffers)) {
      return std::move(block);
    } else {
      auto n = CopyOnWrite(block.get());
      n->reads = std::move(reads);
      n->writes = std::move(writes);
      n->match_buffers = std::move(match_buffers);
      return Block(n);
    }
  }

 private:
  /*! \brief Mapping from src buffer.data to tgt buffer. */
  std::unordered_map<const tir::VarNode*, tir::Buffer> buffer_var_map_;
  /*! \brief The structural equality checker */
  StructuralEqual structural_equal_;

  Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>& regions) const {
    // For now we only allow buffer accesses to the same elements.
    // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but needs to union to `A[vi, vj]`.
    // However, `A[vi, vj], A[vi, vj + 1]` is not allowed for now.
    // Note: the order of returned regions should remain the same as the first occurrence of
    // each region.
    Array<BufferRegion> ret;
    std::unordered_map<const BufferNode*, Region> buffer_region_set;

    for (const BufferRegion& region : regions) {
      auto it = buffer_region_set.find(region->buffer.get());
      if (it == buffer_region_set.end()) {
        ret.push_back(region);
        buffer_region_set[region->buffer.get()] = region->region;
      } else {
        // Duplicate regions on the same buffer must access exactly the same elements.
        ICHECK(structural_equal_(region->region, it->second));
      }
    }

    if (ret.size() == regions.size()) {
      return regions;
    } else {
      return ret;
    }
  }
};
+
/*! \brief A mutator which detects block name duplication and deduplicates the names. */
class BlockNameDeduplicator : public tir::StmtMutator {
 private:
  Stmt VisitStmt_(const BlockNode* op) final {
    Block block = Downcast<Block>(tir::StmtMutator::VisitStmt_(op));

    String name = GetUniqueName(block->name_hint);

    // Only copy the block when the name actually changed.
    if (name == block->name_hint) {
      return std::move(block);
    } else {
      ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
      n->name_hint = std::move(name);
      return Stmt(n);
    }
  }

  /*!
   * \brief Return `prefix` if unseen; otherwise append "_<counter>" (bumping the
   * counter stored for `prefix`) until the candidate is free. The chosen name is
   * then registered so later calls will not reuse it.
   * \note `it` is only dereferenced inside the loop, which can only be entered when
   * `prefix` is already in the map, so `it` is valid there.
   */
  String GetUniqueName(const String& prefix) {
    String unique_prefix = prefix;
    auto it = name_count_.find(prefix);
    while (name_count_.count(unique_prefix)) {
      unique_prefix = prefix + "_" + std::to_string(++it->second);
    }
    name_count_[unique_prefix] = 0;
    return unique_prefix;
  }

  // TODO(relax-team): It should detect the number suffix and do renaming properly,
  // e.g. GetUniqueName("name1") should return "name2" instead of "name10".
  /*! \brief The count map to make block name unique. */
  std::unordered_map<String, int> name_count_;
};
+
+}  // namespace tir
+
+namespace relax {
+
+class FusedTIRConstructor : public ExprVisitor {
+ public:
+  /*!
+   * \brief Construct a fused TIR PrimFunc from a relax sub-function
+   * \param mod The IRModule
+   * \param gv The global var of relax subfunction to be fused into one 
PrimFunc
+   * \return The fused TIR PrimFunc
+   */
+  static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+    FusedTIRConstructor visitor(mod, gv->name_hint);
+    BaseFunc f = mod->Lookup(gv);
+    CHECK(f->IsInstance<relax::FunctionNode>())
+        << "Expected relax functions, but got: " << f->GetTypeKey();
+    CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
+        << "Expected a function with attr `kPrimitive`";
+    visitor(Downcast<relax::Function>(f));
+    return visitor.fused_tir_;
+  }
+
+ private:
  // Instances are only created through GetFusedTIR; `func_name` seeds the fused
  // PrimFunc's name.
  explicit FusedTIRConstructor(const IRModule& mod, const String& func_name)
      : mod_(mod), func_name_(func_name) {}
+
  // Entry visitor: builds the fused PrimFunc's parameter buffers, walks the body to
  // allocate intermediate buffers, then appends the output buffers as trailing params.
  void VisitExpr_(const FunctionNode* func) final {
    // Step 1. Create buffers for function params
    for (const Var& relax_param : func->params) {
      // A single relax param may expand to several TIR params/buffers (e.g. tuples).
      auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param),  //
                                        relax_param->name_hint());
      const Array<tir::Var>& params = ret.first;
      const Array<tir::Buffer>& buffers = ret.second;
      ICHECK_EQ(params.size(), buffers.size());
      for (size_t i = 0; i < params.size(); ++i) {
        func_info_.buffer_map.Set(params[i], buffers[i]);
        func_info_.params.push_back(params[i]);
      }
      func_info_.expr2buffers.Set(relax_param, buffers);
    }

    // Step 2. Visit Function body and create intermediate buffers
    ExprVisitor::VisitExpr_(func);

    // Step 3. Create and remap buffers for function output
    ICHECK(func->body->IsInstance<SeqExprNode>())
        << "Function body is expected to be a SeqExpr, but got: " << func->body->GetTypeKey();
    Expr body = Downcast<SeqExpr>(func->body)->body;
    auto it = func_info_.expr2buffers.find(body);
    ICHECK(it != func_info_.expr2buffers.end())
        << "Fail to detect output buffers for function body";
    const Array<tir::Buffer>& buffers = (*it).second;
    // Outputs become trailing handle params of the fused PrimFunc (destination-passing).
    for (size_t i = 0; i < buffers.size(); ++i) {
      tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle()));
      func_info_.buffer_map.Set(param, buffers[i]);
      func_info_.params.push_back(param);
      func_info_.output_buffers.insert(buffers[i].get());
    }

    // Step 4. Create PrimFunc
    fused_tir_ = ConstructFunc();
  }
+
+  void VisitBinding_(const VarBindingNode* binding) final {
+    // Update expr2buffers by visiting values.
+    this->VisitExpr(binding->value);
+    auto it = func_info_.expr2buffers.find(binding->value);
+    if (it != func_info_.expr2buffers.end()) {
+      // assign binding var to the buffers of the value
+      func_info_.expr2buffers.Set(binding->var, (*it).second);
+    } else {
+      LOG(FATAL) << "Unsupported binding value: " << binding->value;
+    }
+  }
+
  // Primitive (fused) functions are produced by FuseOps and never contain MatchCast
  // bindings, so reject them outright.
  void VisitBinding_(const MatchCastNode* match_cast) final {
    LOG(FATAL) << "MatchCast is unsupported in primitive functions";
  }
+
  // Inline one call_tir callee into the fused PrimFunc being built: its root-block
  // allocations and body are appended, and its parameters are wired to buffers.
  void VisitExpr_(const CallNode* call) final {
    ExprVisitor::VisitExpr_(call);
    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
    ICHECK(call->op == call_tir_op_)
        << "Only call_tir is supported in primitive function, but got: " << GetRef<Expr>(call);

    // Step 1. Get Global var and PrimFunc
    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
    Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
    ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: "
                                 << gv;
    // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication
    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());

    // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block
    // TODO(Siyuan): support un-schedulable functions.
    ICHECK(prim_func->body->IsInstance<tir::BlockRealizeNode>())
        << "Only schedulable functions (whose body is the root block) can be fused";
    const tir::BlockRealize& root_realize = Downcast<tir::BlockRealize>(prim_func->body);
    const tir::Block& root_block = root_realize->block;

    // Step 4. Add all the original alloc_buffers and body to the fused function.
    func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(),
                                    root_block->alloc_buffers.begin(),
                                    root_block->alloc_buffers.end());
    func_info_.bodies.push_back(root_block->body);

    // Step 5. Map input arguments to buffer
    MapInputBuffer(prim_func, call->args[1]);
    size_t num_output_buffers = GetCallTIROutputSize(call);
    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, num_output_buffers);
    // Update fused func name: the fused name is the concatenation of all callee names.
    func_info_.global_name += "_" + gv->name_hint;
  }
+
  // Map a TupleGetItem to the slice of flattened buffers belonging to the
  // selected field. Tuples are flattened to one buffer per tensor leaf, so a
  // field's slice is found by summing the leaf counts of the preceding fields.
  void VisitExpr_(const TupleGetItemNode* tuple_get_item) final {
    ExprVisitor::VisitExpr_(tuple_get_item);
    auto it = func_info_.expr2buffers.find(tuple_get_item->tuple);
    if (it != func_info_.expr2buffers.end()) {
      int begin_buf_idx = 0;
      int end_buf_idx = 0;
      const TupleType& tuple_type = Downcast<TupleType>(tuple_get_item->tuple->checked_type());
      // Skip the buffers of all fields before `index` (a field may itself be
      // a nested tuple that spans several buffers).
      for (int i = 0; i < tuple_get_item->index; ++i) {
        begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
      }
      end_buf_idx = begin_buf_idx + GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
      func_info_.expr2buffers.Set(
          GetRef<Expr>(tuple_get_item),
          {(*it).second.begin() + begin_buf_idx, (*it).second.begin() + end_buf_idx});
    }
  }
+
+  void VisitExpr_(const TupleNode* tuple) final {
+    ExprVisitor::VisitExpr_(tuple);
+    Array<tir::Buffer> buffers;
+    for (const Expr& expr : tuple->fields) {
+      auto it = func_info_.expr2buffers.find(expr);
+      if (it != func_info_.expr2buffers.end()) {
+        buffers.insert(buffers.end(), (*it).second.begin(), 
(*it).second.end());
+      }
+    }
+    if (!buffers.empty()) {
+      func_info_.expr2buffers.Set(GetRef<Expr>(tuple), buffers);
+    }
+  }
+
  // Constants have no backing TIR buffer in this scheme, so a primitive
  // function that embeds one cannot be fused.
  void VisitExpr_(const ConstantNode* op) final {
    LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
  }
+
+  /********** Helper Functions **********/
+
+  /*!
+   * \brief Pattern match op to a TIR function and look it up.
+   * \return The TIR function, or NullOpt if patter match fails.
+   */
+  Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
+    // NOTE: as check works for nullptr(returns null)
+    Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
+    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+      return GetRef<tir::PrimFunc>(pfunc);
+    } else {
+      return NullOpt;
+    }
+  }
+
+  /*!
+   * \brief Get the number of outputs for a call_tir node.
+   * \return The number of outputs.
+   */
+  static size_t GetCallTIROutputSize(const CallNode* call) {
+    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+    ICHECK(call->op.same_as(call_tir_op_));
+    ICHECK_EQ(call->sinfo_args.size(), 1);
+    if (const auto* tuple_sinfo = 
call->sinfo_args[0].as<TupleStructInfoNode>()) {
+      return tuple_sinfo->fields.size();
+    } else {
+      return 1;
+    }
+  }
+
+  /*! \brief Map old TIR func param buffer to new buffer, and then update 
`buffer_subst_map` */
+  void MapArgsToBuffer(const Array<Expr> args, const Array<tir::Buffer>& 
buffers) {
+    size_t buffer_idx = 0;
+    for (const Expr& arg : args) {
+      if (const auto* v = arg.as<VarNode>()) {
+        auto it = func_info_.expr2buffers.find(GetRef<Var>(v));
+        // Substitute the buffer with the already allocated one if it is an 
intermediate var
+        if (it != func_info_.expr2buffers.end()) {
+          for (const tir::Buffer& target_buffer : (*it).second) {
+            ICHECK_LT(buffer_idx, buffers.size());
+            const tir::Buffer& buffer = buffers[buffer_idx];
+            // TODO(relax-team): Add support for symbolic shape fusion
+            for (const PrimExpr& shape_expr : buffer->shape) {
+              ICHECK(shape_expr.as<IntImmNode>()) << "Only support constant 
shape fusion for now";
+            }
+            func_info_.buffer_subst_map.Set(buffer, target_buffer);
+            buffer_idx++;
+          }
+        }
+      }
+    }
+    // Make sure every buffers are maped.
+    ICHECK_EQ(buffer_idx, buffers.size());
+  }
+
  /*!
   * \brief Update buffer mapping `func_info_.buffer_subst_map` for input args
   * \param func The old TIR PrimFunc
   * \param args The relax-level call_tir arguments: either a Tuple of args or a single arg.
   */
  void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) {
    Array<Expr> arg_list;
    Array<tir::Buffer> buffer_list;
    if (const auto* arg_tuple = args.as<TupleNode>()) {
      arg_list = arg_tuple->fields;
    } else {
      arg_list = {args};
    }

    // The param list also holds the output params at its tail, so it may be
    // longer than the argument list; only the leading params are inputs.
    ICHECK_GE(func->params.size(), arg_list.size());
    for (size_t i = 0; i < arg_list.size(); ++i) {
      const tir::Var& param = func->params[i];
      const tir::Buffer& buffer = func->buffer_map.at(param);
      buffer_list.push_back(buffer);
    }

    MapArgsToBuffer(arg_list, buffer_list);
  }
+
+  /*!
+   * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the 
PrimFunc output(s) are
+   * intermediate results.
+   * \param expr The relax Expr, which can be binding vars or binding values.
+   * \param func The old TIR PrimFunc
+   * \param output_size The number of output params. All output params are at 
the end of param list.
+   */
+  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, 
size_t output_size) {
+    size_t n = func->params.size();
+    ICHECK_GE(n, output_size);
+    // Allocate intermediate buffer
+    Array<tir::Buffer> alloc_buffers;
+    for (size_t i = 0; i < output_size; ++i) {
+      const tir::Var& param = func->params[n - output_size + i];
+      const tir::Buffer& buffer = func->buffer_map.at(param);
+      func_info_.alloc_buffers.push_back(buffer);
+      alloc_buffers.push_back(buffer);
+    }
+    // Update expr2buffers
+    func_info_.expr2buffers.Set(expr, alloc_buffers);
+  }
+
+  /*!
+   * \brief Create an TIR func params and buffers with specified relax type 
and shape
+   * \param struct_info The struct info
+   * \param name_hint The name hint for params and buffers
+   * \param index The index used for unique name_hint if type is Tuple.
+   *              -1 means no need to add postfix since the relax param is not 
a Tuple.
+   * \return The created TIR func params and buffers
+   */
+  static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
+      StructInfo struct_info, const String& name_hint, int index = -1) {
+    Array<tir::Var> params;
+    Array<tir::Buffer> buffers;
+    if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
+      // Case 1. the relax param is a DynTensor, we directly create a tir var 
and buffer
+      const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
+      ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with 
symbolic shape.";
+
+      String name = index == -1 ? name_hint : name_hint + "_" + 
std::to_string(index);
+      DataType dtype = tensor->dtype;
+      tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
+      // Differentiate buffer name and param name by adding prefix `v_` to 
param
+      // Every symbol should be unique in TVMScript, and Buffer is used more 
than param
+      // So we decide to make sure buffer names have better readability.
+      tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
+      params.push_back(std::move(param));
+      buffers.push_back(std::move(buffer));
+    } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
+      // Case 2. the relax param is a Tuple, we recursively visit each field 
until it's a DynTensor
+      // Enable postfix
+      if (index == -1) index = 0;
+      for (size_t i = 0; i < tuple->fields.size(); ++i) {
+        auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
+        const Array<tir::Var>& ret_params = ret.first;
+        const Array<tir::Buffer>& ret_buffers = ret.second;
+        ICHECK_EQ(ret_params.size(), ret_buffers.size());
+        // Adding tuple field results to the end of params and buffers.
+        params.insert(params.end(), ret_params.begin(), ret_params.end());
+        buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
+        index += ret_params.size();
+      }
+    } else {
+      ICHECK(false) << "shapes are expected to be ShapeExprNode or TupleNode";
+    }
+    return std::make_pair(params, buffers);
+  }
+
  /*!
   * \brief Construct fused TIR func with collected FuseFuncInfo
   * \return The fused TIR PrimFunc
   */
  tir::PrimFunc ConstructFunc() {
    Map<String, ObjectRef> attr_map;
    attr_map.Set("tir.noalias", tir::const_true());
    // "fused" alone means no call_tir was ever visited, which is a bug upstream.
    ICHECK(func_info_.global_name != "fused");
    // Remove output buffers from func_info_.alloc_buffers: outputs are passed
    // in through the buffer_map, not allocated inside the fused function.
    Array<tir::Buffer> alloc_buffers;
    for (const tir::Buffer& buf : func_info_.alloc_buffers) {
      if (func_info_.output_buffers.count(buf.get()) == 0) {
        alloc_buffers.push_back(buf);
      }
    }
    // Concatenate the inlined bodies, de-duplicate block names, and rewrite
    // callee param buffers to the fused function's buffers.
    tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
    body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
    // Wrap everything in a fresh root block holding the intermediate allocations.
    body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
    body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
    tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
                       DictAttrs(attr_map));
    return func;
  }
+
+  /*! \brief Get DynTensor numbers from recursive Tuples. */
+  static size_t GetTotalTensorSize(const Type& type) {
+    if (type.as<DynTensorTypeNode>()) {
+      return 1;
+    } else if (const auto* tuple_type = type.as<TupleTypeNode>()) {
+      size_t num = 0;
+      for (const Type& type : tuple_type->fields) {
+        num += GetTotalTensorSize(type);
+      }
+      return num;
+    } else {
+      LOG(FATAL) << "DynTensorType and TupleType are expect, but got: " << 
type;
+      return 0;
+    }
+  }
+
  /********** Function Info **********/

  /*! \brief auxiliary information for FuseTIR */
  struct FuseFuncInfo {
    /*! \brief The arguments for calling prim_func */
    Array<Expr> arguments;
    /*!
     * \brief The map from each dataflow var (intermediate var) to the corresponding buffers
     * allocated in the fused func
     */
    Map<Expr, Array<tir::Buffer>> expr2buffers;
    /*! \brief The buffers to allocate in the fused func */
    Array<tir::Buffer> alloc_buffers;
    /*! \brief The bodies of the original funcs, which is also the body of the fused func. */
    Array<tir::Stmt> bodies;
    /*! \brief The params of the fused function */
    Array<tir::Var> params;
    /*!
     * \brief The map from buffer in original functions to corresponding buffer in the fused
     * function
     */
    Map<tir::Buffer, tir::Buffer> buffer_subst_map;
    /*! \brief The `buffer_map` in the fused function */
    Map<tir::Var, tir::Buffer> buffer_map;
    /*! \brief The output buffers in the function buffer_map */
    std::unordered_set<const tir::BufferNode*> output_buffers;
    /*! \brief The name of the fused function; grows as "_<callee>" per inlined call_tir. */
    std::string global_name = "fused";
  };

  /*! \brief The IRModule */
  const IRModule& mod_;
  /*! \brief The name hint for the input func. */
  String func_name_;
  /*! \brief The helper info to fuse TIR prim_func */
  FuseFuncInfo func_info_;
  /*! \brief The tir function after fusion */
  tir::PrimFunc fused_tir_;
+};
+
/*!
 * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR.
 */
class TIRFuseMutator : public ExprMutator {
 public:
  static IRModule Transform(const IRModule& mod) {
    // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder.
    TIRFuseMutator mutator(mod);
    // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
    for (const auto& kv : mod->functions) {
      const GlobalVar& gv = kv.first;
      const BaseFunc& func = kv.second;
      // Only fuse primitive relax functions
      if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
        mutator.fused_tir_funcs_.Set(gv, fused_tir);
      }
    }

    // Step 2. Update all non-primitive relax functions and add it, with the dependent function,
    // into the new IRModule
    for (const auto& kv : mod->functions) {
      const GlobalVar& gv = kv.first;
      const BaseFunc& func = kv.second;
      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
        relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
        mutator.builder_->AddFunction(update_func, gv->name_hint);
      }
    }
    return mutator.builder_->GetContextIRModule();
  }

 private:
  explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}

  using ExprMutator::VisitExpr_;

  // Get shape from call tir
  static Expr GetCallTIRShape(StructInfo sinfo) {
    if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
      Array<Expr> fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); });
      return Tuple(fields);
    } else {
      auto* tensor = sinfo.as<TensorStructInfoNode>();
      ICHECK(tensor) << "FuseTIR can only take tensor or tuple type";
      auto* shape_expr = tensor->shape.as<ShapeExprNode>();
      ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape";
      return GetRef<ShapeExpr>(shape_expr);
    }
  }

  // Rewrite calls: calls to fused primitive relax functions become call_tir to
  // the fused PrimFunc; existing call_tir nodes are re-emitted so their callee
  // PrimFuncs survive into the fresh module; anything else is left untouched.
  Expr VisitExpr_(const CallNode* op) final {
    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
    Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));

    if (call->op->IsInstance<GlobalVarNode>()) {
      // Case 1. It is a relax cross function call
      GlobalVar old_gv = Downcast<GlobalVar>(call->op);
      auto it = fused_tir_funcs_.find(old_gv);
      if (it != fused_tir_funcs_.end()) {
        const tir::PrimFunc& fused_tir = (*it).second;
        // Case 1.1. It calls a primitive relax function, update the call into a call_tir
        GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
        // Step a. Flatten all args since call_tir does not support Tuple value.
        Array<Expr> arg_list;
        for (const Expr& arg : call->args) {
          Array<Expr> flattened = FlattenArg(arg);
          arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
        }
        // Step b. Create call_tir
        Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
        return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
      } else {
        // Case 1.2. The callee function is not primitive, nothing to do.
        return call;
      }
    } else if (call->op == call_tir_op_) {
      // Case 2. It is a call_tir, re-emit the PrimFunc.
      GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
      GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
      return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span);
    } else {
      // Case 3. CallNode in other types. Leave it as it is.
      return call;
    }
  }

  /********** Helper Functions **********/

  /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
  Array<Expr> FlattenArg(const Expr& arg) {
    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
      Array<Expr> arg_list;
      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
        // Recurse: a field may itself be a tuple.
        Array<Expr> flattened = FlattenArg(new_arg);
        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
      }
      return arg_list;
    } else {
      return {arg};
    }
  }

 private:
  /*! \brief The IRModule (borrowed; outlives the mutator within Transform). */
  const IRModule& mod_;
  /*! \brief The map from global var of primitive relax function to generated prim func. */
  Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
};
+
+IRModule FuseTIR(IRModule mod) {
+  mod = TIRFuseMutator::Transform(mod);
+  return mod;
+}
+
+namespace transform {
+
+Pass FuseTIR() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
+      [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
+  return CreateModulePass(/*pass_function=*/pass_func,  //
+                          /*opt_level=*/0,              //
+                          /*pass_name=*/"FuseTIR",      //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py 
b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
new file mode 100644
index 0000000000..73c6537869
--- /dev/null
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -0,0 +1,360 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T
+
+
class OpPatternKind(enum.IntEnum):
    # Python-side mirror of the op-pattern kinds produced by
    # AnnotateTIROpPattern; presumably these match the C++ OpPatternKind
    # enum values (note the gap before kTuple) — TODO confirm against the
    # C++ definition if values ever change.
    kElemWise = 0
    kBroadcast = 1
    kInjective = 2
    kCommReduce = 3
    kOutEWiseFusable = 4
    kTuple = 7
    kOpaque = 8
+
+
def test_annotate_opkind_outewisefusable():
    """A matmul (reduction whose output is elementwise-fusable) is kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
            T.func_attr({"global_symbol": "tir_matmul"})
            m = T.var("int32")
            n = T.var("int32")
            k = T.var("int32")
            A = T.match_buffer(x, (m, n))
            B = T.match_buffer(y, (n, k))
            C = T.match_buffer(z, (m, k))

            for i, j, k in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_outewisefusable_int_var_signature():
    """Same as the matmul case, but with shape vars passed as scalar params."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64):
            T.func_attr({"global_symbol": "tir_matmul"})
            A = T.match_buffer(x, (m, n))
            B = T.match_buffer(y, (n, k))
            C = T.match_buffer(z, (m, k))

            for i, j, k in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_reduce():
    """A spatial+reduce block is classified as kCommReduce."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def sum(x: T.handle, y: T.handle) -> None:
            # Fixed copy-pasted global_symbol ("elemwise" -> "sum").
            T.func_attr({"global_symbol": "sum"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16,))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SR", [i, j])
                    with T.init():
                        B[vi] = 0.0
                    B[vi] += A[vi, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce
+
+
def test_annotate_opkind_ewise():
    """A purely elementwise block is classified as kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def elemwise(x: T.handle, y: T.handle) -> None:
            T.func_attr({"global_symbol": "elemwise"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16, 16))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 1.0

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_broadcast():
    """Writing each input element to several output positions is kBroadcast."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def broadcast(x: T.handle, y: T.handle) -> None:
            # Fixed copy-pasted global_symbol ("elemwise" -> "broadcast").
            T.func_attr({"global_symbol": "broadcast"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16, 16, 16, 16))

            for i0, j0, i1, j1 in T.grid(16, 16, 16, 16):
                with T.block("matmul"):
                    vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1])
                    B[vi0, vj0, vi1, vj1] = A[vj0, vj1]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast
+
+
def test_annotate_opkind_injective():
    """A bijective index remapping (reshape-like) is kInjective."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def injective(x: T.handle, y: T.handle) -> None:
            # Fixed copy-pasted global_symbol ("elemwise" -> "injective").
            T.func_attr({"global_symbol": "injective"})
            A = T.match_buffer(x, (4, 4, 4, 4))
            B = T.match_buffer(y, (16, 16))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective
+
+
def test_annotate_opkind_bias_add():
    """Broadcast-add of a bias over dim 0 of extent 1 still counts as kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_bias_add(
            A: T.Buffer((1, 1000), "float32"),
            B: T.Buffer((1000,), "float32"),
            C: T.Buffer((1, 1000), "float32"),
        ) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0, i1 in T.grid(1, 1000):
                with T.block("T_add"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(A[ax0, ax1], B[ax1])
                    T.writes(C[ax0, ax1])
                    C[ax0, ax1] = A[ax0, ax1] + B[ax1]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_add_broadcast_with_unit_shape():
    """Broadcast over unit-length dims (B indexed with constant 0s) is kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def add_with_unit_dim_len_broadcast(
            A: T.Buffer((1, 64, 112, 112), "float32"),
            B: T.Buffer((64, 1, 1), "float32"),
            C: T.Buffer((1, 64, 112, 112), "float32"),
        ) -> None:
            T.func_attr({"global_symbol": "add5", "tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(1, 64, 112, 112):
                with T.block("T_add"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0])
                    T.writes(C[ax0, ax1, ax2, ax3])
                    C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_add_zero_dim_element_wise():
    """Adding a rank-0 tensor (scalar buffer B) keeps the op kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def add_zero_dim(
            A: T.Buffer((128,), "float32"),
            B: T.Buffer((), "float32"),
            C: T.Buffer((128,), "float32"),
        ) -> None:
            T.func_attr({"global_symbol": "add8", "tir.noalias": True})
            for i0 in T.serial(128):
                with T.block("T_add"):
                    ax0 = T.axis.spatial(128, i0)
                    T.reads(A[ax0], B[()])
                    T.writes(C[ax0])
                    C[ax0] = A[ax0] + B[()]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_pooling():
    """Pad + windowed max-reduce (max_pool2d) is classified kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def max_pool2d(
            rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"),
            tensor_1: T.Buffer((1, 64, 56, 56), "float32"),
        ) -> None:
            # function attr dict
            # Fixed attr key: "T.noalias" was a typo and silently ignored;
            # the recognized key is "tir.noalias" (as in the other tests).
            T.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
            # body
            # with T.block("root")
            pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32")
            for i0, i1, i2, i3 in T.grid(1, 64, 114, 114):
                with T.block("pad_temp"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])
                    T.writes(pad_temp_1[ax0, ax1, ax2, ax3])
                    pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else(
                        1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113,
                        rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1],
                        T.float32(-3.4028234663852886e38),
                        dtype="float32",
                    )
            for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3):
                with T.block("tensor"):
                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                    T.reads(
                        tensor_1[ax0, ax1, ax2, ax3],
                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                    )
                    T.writes(tensor_1[ax0, ax1, ax2, ax3])
                    with T.init():
                        tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38)
                    tensor_1[ax0, ax1, ax2, ax3] = T.max(
                        tensor_1[ax0, ax1, ax2, ax3],
                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                    )

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_softmax():
    """Softmax (max-reduce, exp, sum-reduce, normalize) is kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def softmax(
            rxplaceholder_1: T.Buffer((16, 16), "float32"),
            T_softmax_norm_1: T.Buffer((16, 16), "float32"),
        ) -> None:
            # function attr dict
            # Fixed attr key: "T.noalias" was a typo and silently ignored;
            # the recognized key is "tir.noalias" (as in the other tests).
            T.func_attr({"global_symbol": "softmax", "tir.noalias": True})
            # body
            # with T.block("root")
            T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32")
            T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32")
            T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32")
            for i0_7, i1_3 in T.grid(16, 16):
                with T.block("T_softmax_maxelem"):
                    i0_8, k = T.axis.remap("SR", [i0_7, i1_3])
                    T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
                    T.writes(T_softmax_maxelem_1[i0_8])
                    with T.init():
                        T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_1[i0_8] = T.max(
                        T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]
                    )
            for i0_9, i1_4 in T.grid(16, 16):
                with T.block("T_softmax_exp"):
                    i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4])
                    T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
                    T.writes(T_softmax_exp_1[i0_10, i1_5])
                    T_softmax_exp_1[i0_10, i1_5] = T.exp(
                        rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32"
                    )
            for i0_11, i1_6 in T.grid(16, 16):
                with T.block("T_softmax_expsum"):
                    i0_12, k = T.axis.remap("SR", [i0_11, i1_6])
                    T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
                    T.writes(T_softmax_expsum_1[i0_12])
                    with T.init():
                        T_softmax_expsum_1[i0_12] = T.float32(0)
                    T_softmax_expsum_1[i0_12] = (
                        T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
                    )
            for i0_13, i1_7 in T.grid(16, 16):
                with T.block("T_softmax_norm"):
                    i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7])
                    T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
                    T.writes(T_softmax_norm_1[i0_14, i1_8])
                    T.block_attr({"axis": 1})
                    T_softmax_norm_1[i0_14, i1_8] = (
                        T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
                    )

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
# NOTE(review): "bufer" in the test name is a typo ("buffer"); left as-is to
# avoid churning the test identifier.
def test_multiple_bufer_stores_fallback():
    """A block whose body stores to the same buffer multiple times (cumsum)
    cannot be pattern-analyzed and must fall back to kOpaque."""

    @tvm.script.ir_module
    class CumsumModule:
        @T.prim_func
        def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")):
            rxplaceholder = T.match_buffer(
                var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1
            )
            with T.block("cumsum_generic"):
                T.reads(rxplaceholder[0:10, 0:16])
                T.writes(out_buf[0:160])
                for fused in T.parallel(1):
                    out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16]
                    for v_k in T.serial(159):
                        out_buf[fused * 160 + (v_k + 1)] = (
                            out_buf[fused * 160 + (v_k + 1 - 1)]
                            + rxplaceholder[
                                (fused * 160 + (v_k + 1)) // 16,
                                (fused * 160 + (v_k + 1)) % 16,
                            ]
                        )

    mod = CumsumModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque
+
+
if __name__ == "__main__":
    # Allow running this test file directly (outside pytest).
    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
new file mode 100644
index 0000000000..1a228bb268
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -0,0 +1,759 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax, topi
+from tvm.script import relax as R
+
+
+def _check(mod_actual, mod_expected):
+    """Annotate + fuse ``mod_actual`` and compare it structurally to ``mod_expected``.
+
+    Both modules run AnnotateTIROpPattern so the op_pattern attributes are
+    present on each side; only the actual module additionally runs FuseOps.
+    """
+    mod_actual = relax.transform.AnnotateTIROpPattern()(mod_actual)
+    mod_actual = relax.transform.FuseOps()(mod_actual)
+    mod_expected = relax.transform.AnnotateTIROpPattern()(mod_expected)
+    tvm.ir.assert_structural_equal(mod_actual, mod_expected)
+
+
+def test_fuse_simple():
+    """Simple testcase."""
+
+    def before():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    # Expected: add/exp/squeeze are grouped into one "Primitive" relax function,
+    # with the scalar constant lifted into a parameter (p0) of that function.
+    def expected():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+        fused_add_exp_squeeze = 
bb.get().get_global_var("fused_add_exp_squeeze")
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                # main now passes the original constant as the lifted argument.
+                gv = bb.emit_output(
+                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, 
"float32")])
+                )
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_conv2d_fuse():
+    """Test fusion case of conv2d"""
+
+    # dtype is parameterized; the fusion structure is identical for all dtypes
+    # checked at the bottom (float32 / float16 / int8).
+    def before(dtype):
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
+        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
+        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
+        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
+        with bb.function("main", [x, w1, w2, w3]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
+                lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, 
padding=1, dilation=1)
+                # this is the next dominator.
+                lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1)
+                lv3 = bb.emit_te(topi.add, lv1, lv2)
+                # second path
+                lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, 
padding=0, dilation=1)
+                lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, 
padding=1, dilation=1)
+                gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected(dtype):
+        bb = relax.BlockBuilder()
+
+        # Grouped function 1
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
+        w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
+        p0 = relax.Var("p0", R.Tensor((), dtype))
+        with bb.function("fused_conv2d_add1_add2", [x, w, p0], 
attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.conv2d,
+                    x,
+                    w,
+                    strides=1,
+                    padding=1,
+                    dilation=1,
+                    primfunc_name_hint="conv2d",
+                )
+                # Name hints carry dedup suffixes ("add1", "add2") so the fused
+                # PrimFunc names stay unique within the module.
+                lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
+                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, 
primfunc_name_hint="add2"))
+            bb.emit_func_output(gv)
+
+        # Grouped function 2
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
+        w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
+        y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
+        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 
1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.conv2d,
+                    x,
+                    w,
+                    strides=1,
+                    padding=0,
+                    dilation=1,
+                    primfunc_name_hint="conv2d1",
+                )
+                gv = bb.emit_output(bb.call_te(topi.add, lv0, y, 
primfunc_name_hint="add2"))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        mod = bb.get()
+        fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
+        fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
+        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
+        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
+        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
+        with bb.function("main", [x, w1, w2, w3]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
+                lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, 
relax.const(1, dtype)]))
+                # The second conv2d stays unfused in main; its result feeds the
+                # fused_conv2d1_add2 group as an extra input (y).
+                lv2 = bb.emit_te(
+                    topi.nn.conv2d,
+                    lv1,
+                    w3,
+                    strides=1,
+                    padding=1,
+                    dilation=1,
+                )
+                gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, 
lv2]))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before("float32"), expected("float32"))
+    _check(before("float16"), expected("float16"))
+    _check(before("int8"), expected("int8"))
+
+
+def test_concatenate():
+    """Test fusion case involving concat op and Tuple node"""
+
+    def before():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.pool2d,
+                    x,
+                    kernel=(2, 2),
+                    stride=(2, 2),
+                    dilation=(1, 1),
+                    padding=(0, 0, 0, 0),
+                    pool_type="max",
+                )
+                lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, 
scale_w=2.0)
+                lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1)
+                gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, 
"float32")))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    # Expected: upsampling/concatenate/add fuse into one group; pool2d stays in
+    # main and its result enters the group as parameter w.
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_upsampling_concatenate_add", [w, x, p0], 
attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, 
scale_w=2.0)
+                lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1)
+                gv = bb.emit_output(bb.call_te(topi.add, lv1, p0))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_upsampling_concatenate_add = bb.get().get_global_var(
+            "fused_upsampling_concatenate_add"
+        )
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.pool2d,
+                    x,
+                    kernel=(2, 2),
+                    stride=(2, 2),
+                    dilation=(1, 1),
+                    padding=(0, 0, 0, 0),
+                    pool_type="max",
+                )
+                gv = bb.emit_output(
+                    relax.Call(
+                        fused_upsampling_concatenate_add, (lv0, x, 
relax.const(1, "float32"))
+                    )
+                )
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_tuple_root():
+    """Test fusion case where Tuple node is the root in its group"""
+
+    def before():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.pool2d,
+                    x,
+                    kernel=(2, 2),
+                    stride=(2, 2),
+                    dilation=(1, 1),
+                    padding=(0, 0, 0, 0),
+                    pool_type="max",
+                )
+                lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, 
scale_w=2.0)
+                # The function output is a Tuple, which roots its own group.
+                gv = bb.emit_output((lv1, x))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    # The fusion is supposed to make no change.
+    _check(before(), before())
+
+
+def test_fuse_tuple_get_elemwise():
+    """Elementwise consumers of TupleGetItem(split) fuse into one group; dense does not."""
+
+    def before(dim: int):
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor((1, dim), "float32"))
+        w = relax.Var("w", R.Tensor((3 * dim, dim), "float32"))
+        with bb.function("main", [x, w]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.dense, x, w)
+                lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, 
axis=1)
+                lv2 = bb.emit(relax.TupleGetItem(lv1, 0))
+                lv3 = bb.emit_te(topi.sigmoid, lv2)
+                lv4 = bb.emit(relax.TupleGetItem(lv1, 1))
+                lv5 = bb.emit_te(topi.tanh, lv4)
+                lv6 = bb.emit(relax.TupleGetItem(lv1, 2))
+                lv7 = bb.emit_te(topi.exp, lv6)
+                lv8 = bb.emit_te(topi.multiply, lv5, lv7)
+                gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected(dim: int):
+        bb = relax.BlockBuilder()
+
+        # Grouped function: everything downstream of dense, including the
+        # split and the TupleGetItem accesses, lives inside the fused function.
+        dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32"))
+        with bb.function(
+            "fused_split_sigmoid_tanh_exp_multiply_add", [dense], 
attrs={"Primitive": 1}
+        ):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, 
axis=1)
+                lv1 = bb.emit(relax.TupleGetItem(lv0, 0))
+                lv2 = bb.emit_te(topi.sigmoid, lv1)
+                lv3 = bb.emit(relax.TupleGetItem(lv0, 1))
+                lv4 = bb.emit_te(topi.tanh, lv3)
+                lv5 = bb.emit(relax.TupleGetItem(lv0, 2))
+                lv6 = bb.emit_te(topi.exp, lv5)
+                lv7 = bb.emit_te(topi.multiply, lv4, lv6)
+                gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var(
+            "fused_split_sigmoid_tanh_exp_multiply_add"
+        )
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, dim), "float32"))
+        w = relax.Var("w", R.Tensor((3 * dim, dim), "float32"))
+        with bb.function("main", [x, w]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.dense, x, w)
+                gv = 
bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,)))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    dim = 10
+    _check(before(dim), expected(dim))
+
+
+def test_tuple_get_root():
+    """A TupleGetItem that roots its group fuses with the split; dense stays unfused."""
+
+    def before(dim: int):
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
+        w = relax.Var("w", R.Tensor((dim, dim), "float32"))
+        with bb.function("main", [x, w]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
+                lv1 = bb.emit(relax.TupleGetItem(lv0, 0))
+                gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected(dim: int):
+        bb = relax.BlockBuilder()
+
+        # Grouped function: split + TupleGetItem only; the dense consumer is
+        # not pulled into the group.
+        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
+        with bb.function("fused_split", [x], attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
+                gv = bb.emit_output(relax.TupleGetItem(lv0, 0))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_split = bb.get().get_global_var("fused_split")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
+        w = relax.Var("w", R.Tensor((dim, dim), "float32"))
+        with bb.function("main", [x, w]):
+            with bb.dataflow():
+                lv0 = bb.emit(relax.Call(fused_split, (x,)))
+                gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    dim = 10
+    _check(before(dim), expected(dim))
+
+
+def test_tuple_intermediate():
+    """An intermediate Tuple (concatenate input) does not break the fused group."""
+
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.squeeze, x)
+                lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
+                lv2 = bb.emit_te(topi.squeeze, lv0)
+                lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32"))
+                lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32"))
+                lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
+                lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1)
+                lv7 = bb.emit_te(topi.squeeze, lv6)
+                gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, 
"float32")))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    # Expected: the whole chain fuses into a single function; each scalar
+    # constant becomes one lifted parameter (p0..p4).
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        p1 = relax.Var("p1", R.Tensor((), "float32"))
+        p2 = relax.Var("p2", R.Tensor((), "float32"))
+        p3 = relax.Var("p3", R.Tensor((), "float32"))
+        p4 = relax.Var("p4", R.Tensor((), "float32"))
+        with bb.function(
+            "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1",
+            [x, p0, p1, p2, p3, p4],
+            attrs={"Primitive": 1},
+        ):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.squeeze, x)
+                lv1 = bb.emit_te(topi.add, lv0, p0)
+                lv2 = bb.emit_te(topi.squeeze, lv0)
+                lv3 = bb.emit_te(topi.add, lv2, p1)
+                lv4 = bb.emit_te(topi.add, lv3, p2)
+                lv5 = bb.emit_te(topi.add, lv0, p3)
+                lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1)
+                lv7 = bb.emit_te(topi.squeeze, lv6)
+                gv = bb.emit_output(bb.call_te(topi.add, lv7, p4))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_func = bb.get().get_global_var(
+            "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1"
+        )
+
+        # Main func
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(
+                    relax.Call(
+                        fused_func,
+                        (
+                            x,
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                        ),
+                    )
+                )
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_tuple_consecutive():
+    """Consecutive Tuple/concatenate chains fuse into one large injective group."""
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1)
+                lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32"))
+                lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1)
+                lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32"))
+                lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1)
+                lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32"))
+                lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1)
+                lv16 = bb.emit_te(
+                    topi.nn.pool2d,
+                    lv15,
+                    kernel=(2, 2),
+                    stride=(2, 2),
+                    dilation=(1, 1),
+                    padding=(0, 0, 0, 0),
+                    pool_type="max",
+                )
+                lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32"))
+                lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32"))
+                gv = bb.emit_output((lv17, lv18))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function 1
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        p1 = relax.Var("p1", R.Tensor((), "float32"))
+        p2 = relax.Var("p2", R.Tensor((), "float32"))
+        p3 = relax.Var("p3", R.Tensor((), "float32"))
+        p4 = relax.Var("p4", R.Tensor((), "float32"))
+        p5 = relax.Var("p5", R.Tensor((), "float32"))
+        p6 = relax.Var("p6", R.Tensor((), "float32"))
+        p7 = relax.Var("p7", R.Tensor((), "float32"))
+        p8 = relax.Var("p8", R.Tensor((), "float32"))
+        p9 = relax.Var("p9", R.Tensor((), "float32"))
+        p10 = relax.Var("p10", R.Tensor((), "float32"))
+        p11 = relax.Var("p11", R.Tensor((), "float32"))
+        with bb.function(
+            
"fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1",
+            [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11],
+            attrs={"Primitive": 1},
+        ):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.add, x, p1)
+                lv2 = bb.emit_te(topi.add, x, p2)
+                lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1)
+                lv4 = bb.emit_te(topi.add, lv3, p3)
+                lv5 = bb.emit_te(topi.add, x, p4)
+                lv6 = bb.emit_te(topi.add, x, p5)
+                lv7 = bb.emit_te(topi.add, x, p6)
+                lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1)
+                lv9 = bb.emit_te(topi.add, lv8, p7)
+                lv10 = bb.emit_te(topi.add, x, p8)
+                lv11 = bb.emit_te(topi.add, x, p9)
+                lv12 = bb.emit_te(topi.add, x, p10)
+                lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1)
+                lv14 = bb.emit_te(topi.add, lv13, p11)
+                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, 
lv14), axis=1))
+            bb.emit_func_output(gv)
+
+        # Grouped function 2: pool2d begins a new group fused with the add
+        # that follows it.
+        concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_pool2d_add2", [concat, p0], 
attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.pool2d,
+                    concat,
+                    kernel=(2, 2),
+                    stride=(2, 2),
+                    dilation=(1, 1),
+                    padding=(0, 0, 0, 0),
+                    pool_type="max",
+                )
+                gv = bb.emit_output(bb.call_te(topi.add, lv0, p0))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        mod = bb.get()
+        fused_func1 = mod.get_global_var(
+            
"fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1"
+        )
+        fused_func2 = mod.get_global_var("fused_pool2d_add2")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit(
+                    relax.Call(
+                        fused_func1,
+                        (
+                            x,
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                            relax.const(1, "float32"),
+                        ),
+                    )
+                )
+                lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, 
"float32"))))
+                # The final add stays in main — presumably because lv1 (lv17 in
+                # `before`) also escapes as a tuple output; confirm against FuseOps.
+                lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32"))
+                gv = bb.emit_output((lv1, lv2))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_inception_like():
+    """Identical conv2d+relu branches reuse a single fused function definition."""
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32"))
+        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32"))
+        w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32"))
+        w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32"))
+        with bb.function("main", [x, w0, w1, w2, w3]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, 
dilation=1)
+                lv1 = bb.emit_te(topi.nn.relu, lv0)
+                lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, 
dilation=1)
+                lv3 = bb.emit_te(topi.nn.relu, lv2)
+                lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1)
+                lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, 
padding=1, dilation=1)
+                lv6 = bb.emit_te(topi.nn.relu, lv5)
+                lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, 
padding=1, dilation=1)
+                lv8 = bb.emit_te(topi.nn.relu, lv7)
+                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), 
axis=1))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function 1 (16-channel input) — shared by the first two branches.
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32"))
+        with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.conv2d,
+                    x,
+                    w,
+                    strides=1,
+                    padding=1,
+                    dilation=1,
+                    primfunc_name_hint="conv2d",
+                )
+                gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0))
+            bb.emit_func_output(gv)
+
+        # Grouped function 2 (32-channel input) — shared by the last two branches.
+        x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32"))
+        w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32"))
+        with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(
+                    topi.nn.conv2d,
+                    x,
+                    w,
+                    strides=1,
+                    padding=1,
+                    dilation=1,
+                    primfunc_name_hint="conv2d1",
+                )
+                gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        mod = bb.get()
+        fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu")
+        fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
+        w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32"))
+        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32"))
+        w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32"))
+        w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32"))
+        with bb.function("main", [x, w0, w1, w2, w3]):
+            with bb.dataflow():
+                # Each fused function is called twice with different weights.
+                lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0)))
+                lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1)))
+                lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1)
+                lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2)))
+                lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3)))
+                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), 
axis=1))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_fuse_parallel_injective():
+    """Parallel injective branches (squeeze / transpose pair) fuse into one group."""
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor((10, 20), "int32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32"))
+                lv1 = bb.emit_te(topi.squeeze, lv0)
+                lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0])
+                lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0])
+                gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function: both branches plus the joining left_shift fuse
+        # into a single Primitive function.
+        x = relax.Var("x", R.Tensor((10, 20), "int32"))
+        p0 = relax.Var("p0", R.Tensor((), "int32"))
+        with bb.function(
+            "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], 
attrs={"Primitive": 1}
+        ):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.squeeze, lv0)
+                lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0])
+                lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], 
primfunc_name_hint="transpose1")
+                gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_func = 
bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((10, 20), "int32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, 
"int32"))))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_softmax():
+    """Test if softmax can be fused with following ops."""
+
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor((16, 16), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.softmax, x)
+                gv = bb.emit_output(bb.call_te(topi.cast, lv0, 
dtype="float16"))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    # Expected: softmax fuses with the trailing cast into one Primitive function.
+    def expected():
+        bb = relax.BlockBuilder()
+
+        # Grouped function
+        x = relax.Var("x", R.Tensor((16, 16), "float32"))
+        with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.nn.softmax, x)
+                gv = bb.emit_output(bb.call_te(topi.cast, lv0, 
dtype="float16"))
+            bb.emit_func_output(gv)
+
+        # Get the global variables of the grouped functions
+        fused_func = bb.get().get_global_var("fused_softmax_cast")
+
+        # Main function
+        x = relax.Var("x", R.Tensor((16, 16), "float32"))
+        with bb.function("main", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(relax.Call(fused_func, (x,)))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+if __name__ == "__main__":
+    # Run all tests in this file through TVM's pytest entry point.
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_tir.py 
b/tests/python/relax/test_transform_fuse_tir.py
new file mode 100644
index 0000000000..91edab2bbb
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -0,0 +1,563 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax, topi
+from tvm.script import relax as R
+
+
def _check(mod_before, mod_expected):
    """Run FuseTIR on ``mod_before`` and require structural equality with ``mod_expected``."""
    mod_after = relax.transform.FuseTIR()(mod_before)
    tvm.ir.assert_structural_equal(mod_after, mod_expected)
+
def test_simple():
    """A single fused add -> exp -> squeeze group collapses into one TE computation."""

    def before():
        builder = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))

        with builder.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}):
            with builder.dataflow():
                added = builder.emit_te(topi.add, x, p0)
                exped = builder.emit_te(topi.exp, added)
                out = builder.emit_output(builder.call_te(topi.squeeze, exped))
            builder.emit_func_output(out)
        fused_gv = builder.get().get_global_var("fused_add_exp_squeeze")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with builder.function("main", [x, p0]):
            with builder.dataflow():
                out = builder.emit_output(relax.Call(fused_gv, [x, p0]))
            builder.emit_func_output(out)

        return builder.get()

    def expected():
        # The fused TE equivalent of the grouped relax function.
        def fused_add_exp_squeeze(x, p0):
            return topi.squeeze(topi.exp(topi.add(x, p0)))

        builder = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with builder.function("main", [x, p0]):
            with builder.dataflow():
                out = builder.emit_output(builder.call_te(fused_add_exp_squeeze, x, p0))
            builder.emit_func_output(out)
        return builder.get()

    _check(before(), expected())
+
def test_conv2d_fuse():
    """FuseTIR on two fused conv2d groups plus un-fused ops in main."""

    def before(dtype):
        bb = relax.BlockBuilder()

        # Grouped function 1: conv2d -> add1 -> add2 (the conv result is consumed twice).
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
        p0 = relax.Var("p0", R.Tensor((), dtype))
        with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=1,
                    dilation=1,
                    primfunc_name_hint="conv2d",
                )
                lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Grouped function 2: 1x1 conv2d -> add2.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
        y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=0,
                    dilation=1,
                    primfunc_name_hint="conv2d1",
                )
                gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        mod = bb.get()
        fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
        fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")

        # Main function: calls both grouped functions with un-fused ops in between.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)]))
                lv2 = bb.emit_te(
                    topi.nn.conv2d,
                    lv1,
                    w3,
                    strides=1,
                    padding=1,
                    dilation=1,
                )
                gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected(dtype):
        # Fused TE equivalents of the two grouped relax functions above.
        def fused_conv2d_add1_add2(x, w, p):
            conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1)
            add = topi.add(p, conv)
            return topi.add(conv, add)

        def fused_conv2d1_add2(x, w, p):
            conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1)
            return topi.add(conv, p)

        bb = relax.BlockBuilder()

        # Main function: the grouped calls become direct call_te's to the fused TE funcs.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype))
                lv2 = bb.emit_te(
                    topi.nn.conv2d,
                    lv1,
                    w3,
                    strides=1,
                    padding=1,
                    dilation=1,
                )
                gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before("float32"), expected("float32"))
+
def test_two_subfunction():
    """Two calls to the same fused subfunction share a single fused TE function."""

    def before():
        builder = relax.BlockBuilder()
        arg = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with builder.function("fused_exp_squeeze", [arg], attrs={"Primitive": True}):
            with builder.dataflow():
                exped = builder.emit_te(topi.exp, arg)
                out = builder.emit_output(builder.call_te(topi.squeeze, exped))
            builder.emit_func_output(out)
        mod = builder.get()

        fused_gv = mod.get_global_var("fused_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [x]):
            with builder.dataflow():
                first = builder.emit(relax.Call(fused_gv, [x]))
                second = builder.emit(relax.Call(fused_gv, [first]))
                out = builder.emit_output(second)
            builder.emit_func_output(out)
        return builder.get()

    def expected():
        # The fused TE equivalent; both call sites reuse it.
        def fused_exp_squeeze(x):
            return topi.squeeze(topi.exp(x))

        builder = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [x]):
            with builder.dataflow():
                first = builder.emit_te(fused_exp_squeeze, x)
                second = builder.emit_te(fused_exp_squeeze, first)
                out = builder.emit_output(second)
            builder.emit_func_output(out)
        return builder.get()

    _check(before(), expected())
+
def test_fuse_same_primfunc():
    """A fused group that calls the same op (exp) twice fuses into one TE function."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        # Grouped function: exp -> exp -> squeeze, marked as primitive.
        with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}):
            with bb.dataflow():
                lv1 = bb.emit_te(topi.exp, x1)
                lv2 = bb.emit_te(topi.exp, lv1)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv2))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit(relax.Call(func_gv, [x]))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # Fused TE equivalent applying exp twice before the squeeze.
        def fused_exp_exp_squeeze(x):
            exp = topi.exp(x)
            exp = topi.exp(exp)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit_te(fused_exp_exp_squeeze, x)
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
def test_fuse_with_tuple_as_param():
    """A fused function taking a tuple parameter is flattened into tensor params."""

    def before():
        builder = relax.BlockBuilder()
        tup = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with builder.function("fused_exp_add", [tup], attrs={"Primitive": True}):
            with builder.dataflow():
                first = builder.emit(relax.TupleGetItem(tup, 0))
                second = builder.emit(relax.TupleGetItem(tup, 1))
                exped = builder.emit_te(topi.exp, first)
                out = builder.emit_output(builder.call_te(topi.add, exped, second))
            builder.emit_func_output(out)
        mod = builder.get()

        fused_gv = mod.get_global_var("fused_exp_add")
        tup = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with builder.function("main", [tup]):
            with builder.dataflow():
                out = builder.emit_output(relax.Call(fused_gv, [tup]))
            builder.emit_func_output(out)
        return builder.get()

    def expected():
        # The fused TE function receives the two tuple fields as plain tensors.
        def fused_exp_add(x1, x2):
            return topi.add(topi.exp(x1), x2)

        builder = relax.BlockBuilder()
        tup = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with builder.function("main", [tup]):
            with builder.dataflow():
                # Tuple unpacking now happens in main, before the fused call.
                first = builder.emit(relax.TupleGetItem(tup, 0))
                second = builder.emit(relax.TupleGetItem(tup, 1))
                out = builder.emit_output(builder.call_te(fused_exp_add, first, second))
            builder.emit_func_output(out)
        return builder.get()

    _check(before(), expected())
+
def test_fuse_with_nested_tuple_as_param():
    """A fused function whose tuple parameter nests another tuple is fully flattened."""

    # Parameter type: (Tensor, (Tensor, Tensor)).
    tuple_struct_info = R.Tuple(
        [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])]
    )

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", tuple_struct_info)
        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
            with bb.dataflow():
                # Unpack the outer tuple ...
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv0_exp = bb.emit_te(topi.exp, lv0)
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                # ... and the nested inner tuple.
                lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0))
                lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1))
                lv2 = bb.emit_te(topi.add, lv1_0, lv1_1)
                gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_add_add")
        x = relax.Var("x", tuple_struct_info)
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(func_gv, [x]))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # The fused TE function takes the three leaf tensors directly.
        def fused_exp_add_add(x1, x2, x3):
            exp = topi.exp(x1)
            add = topi.add(x2, x3)
            return topi.add(exp, add)

        bb = relax.BlockBuilder()
        x = relax.Var("x", tuple_struct_info)
        with bb.function("main", [x]):
            with bb.dataflow():
                # Tuple unpacking moves into main before calling the fused func.
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                lv2 = bb.emit(relax.TupleGetItem(lv1, 0))
                lv3 = bb.emit(relax.TupleGetItem(lv1, 1))
                gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
def test_fuse_with_call_tir_in_main():
    """main mixes a call to a fused function with a direct call_tir."""

    def before():
        builder = relax.BlockBuilder()
        arg = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with builder.function("fused_exp_squeeze", [arg], attrs={"Primitive": True}):
            with builder.dataflow():
                exped = builder.emit_te(topi.exp, arg)
                out = builder.emit_output(builder.call_te(topi.squeeze, exped))
            builder.emit_func_output(out)
        mod = builder.get()

        fused_gv = mod.get_global_var("fused_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [x]):
            with builder.dataflow():
                fused_out = builder.emit(relax.Call(fused_gv, [x]))
                added = builder.emit_te(topi.add, fused_out, relax.const(1, "float32"))
                out = builder.emit_output(added)
            builder.emit_func_output(out)
        return builder.get()

    def expected():
        # Fused TE equivalent of the grouped function; the standalone add is untouched.
        def fused_exp_squeeze(x):
            return topi.squeeze(topi.exp(x))

        builder = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [x]):
            with builder.dataflow():
                fused_out = builder.emit_te(fused_exp_squeeze, x)
                added = builder.emit_te(topi.add, fused_out, relax.const(1, "float32"))
                out = builder.emit_output(added)
            builder.emit_func_output(out)
        return builder.get()

    _check(before(), expected())
+
def test_fuse_with_const_in_argument():
    """A constant passed as an argument to a fused function is forwarded correctly."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        x2 = relax.Var("x2", R.Tensor([], "float32"))
        with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x1, x2)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_add_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                # The second argument is a relax.const rather than a Var.
                lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")]))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # Fused TE equivalent of add -> exp -> squeeze.
        def fused_add_exp_squeeze(x, y):
            add = topi.add(x, y)
            exp = topi.exp(add)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                # The constant argument is passed straight into the fused call.
                lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
def test_fuse_tuple_output():
    """A fused function returning a tuple maps to a multi-output TE computation."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))

        with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}):
            with bb.dataflow():
                # Both intermediate results are outputs of the grouped function.
                gv0 = bb.emit_output(bb.call_te(topi.add, x, p0))
                gv1 = bb.emit_output(bb.call_te(topi.exp, gv0))
            bb.emit_func_output(relax.Tuple([gv0, gv1]))
        fused_add_exp = bb.get().get_global_var("fused_add_exp")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with bb.function("main", [x, p0]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        # The fused TE function returns both tensors as a pair.
        def fused_add_exp(x, p0):
            add = topi.add(x, p0)
            exp = topi.exp(add)
            return add, exp

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with bb.function("main", [x, p0]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
def test_fuse_with_immediate_tuple():
    """Tuples built and immediately indexed inside the fused function are elided."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))

        with bb.function("fused_add", [x, y], attrs={"Primitive": True}):
            with bb.dataflow():
                # Construct a nested tuple and read both operands back out of it.
                lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])]))
                lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0))
                lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1))
                lv_y = bb.emit(relax.TupleGetItem(lv0, 1))
                gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y))
            bb.emit_func_output(gv)
        fused_add = bb.get().get_global_var("fused_add")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x, y]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(fused_add, [x, y]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        # The tuple round-trip disappears entirely; only the add remains.
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x, y]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add"))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
def test_fuse_return_partial_result():
    """A fused group whose multi-output intermediate (argmax idx/val) is only
    partially consumed: only the index feeds the following add."""

    def te_argmax_idx_val(val):
        """TE computation returning both the argmax index and the max value."""
        from tvm import te

        def f_combine(x, y):
            lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
            rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
            return lhs, rhs

        def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
            return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)

        argmax = te.comm_reducer(f_combine, f_identity, name="argmax")
        m, n = val.shape
        k = te.reduce_axis((0, n), "k")
        max_idx, max_val = te.compute(
            (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax"
        )
        return max_idx, max_val

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}):
            with bb.dataflow():
                lv = bb.emit_te(te_argmax_idx_val, x)
                # Only element 0 (the index) of the two-element result is used.
                idx = bb.emit(relax.TupleGetItem(lv, 0))
                gv = bb.emit_output(bb.call_te(topi.add, idx, offset))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_argmax_add")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        # Fix: name the Var "offset" (was a copy-paste "x", shadowing the first
        # parameter's name hint).
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("main", [x, offset]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(func_gv, [x, offset]))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # Fused TE function; the max value output is computed but discarded.
        def fused_argmax_add(x, offset):
            idx, _ = te_argmax_idx_val(x)
            idx = topi.add(idx, offset)
            return idx

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("main", [x, offset]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
# Allow running this test file directly via `python <file>`.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index f6d2e4c20e..6e9e14d3dc 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1073,5 +1073,4 @@ def test_class_normalize():
 
 
 if __name__ == "__main__":
-    test_cross_function_call()
     tvm.testing.main()


Reply via email to