This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 85b8a419ab [Unity][Pass] Operator Fusion Passes (#14001)
85b8a419ab is described below
commit 85b8a419ab7a7979367bc9c85642a7c9a2df6e54
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Feb 15 21:35:14 2023 +0800
[Unity][Pass] Operator Fusion Passes (#14001)
[Unity][Pass] Operator fusion passes
This PR introduces three passes for operator fusion:
1. AnnotateTIROpPattern: analyzes the operator kind from PrimFunc.
2. FuseOps: fuse operators for Relax functions, which adds a new fused
relax primitive function.
3. FuseTIR: fuse corresponding TIR PrimFuncs for the fused relax.
---
include/tvm/relax/analysis.h | 11 +
include/tvm/tir/buffer.h | 14 +-
python/tvm/relax/transform/transform.py | 43 +
src/relax/transform/annotate_tir_op_pattern.cc | 55 ++
src/relax/transform/fuse_ops.cc | 909 +++++++++++++++++++++
src/relax/transform/fuse_tir.cc | 728 +++++++++++++++++
.../test_transform_annotate_tir_op_pattern.py | 360 ++++++++
tests/python/relax/test_transform_fuse_ops.py | 759 +++++++++++++++++
tests/python/relax/test_transform_fuse_tir.py | 563 +++++++++++++
tests/python/relax/test_tvmscript_parser.py | 1 -
10 files changed, 3441 insertions(+), 2 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 24cfe5b9bf..a55fe6797d 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const
StructInfo& derived,
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);
+/*!
+ * \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax
FuseOps.
+ *
+ * \param func The PrimFunc to be analyzed.
+ * \return The Op Pattern Kind.
+ *
+ * \note This analysis applies to a TIR function but is primarily used by relax
passes.
+ * As a result we place it under the relax namespace.
+ */
+TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
+
/*!
* \brief Check if the given PrimFunc is essentially doing a reshape operation.
* The reshape operation also includes expand_dims, squeeze, flatten, etc.
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index d7a2aec0b9..e3a853e4c7 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -34,6 +34,18 @@
namespace tvm {
namespace tir {
+#ifndef TVM_INDEX_DEFAULT_I64
+#define TVM_INDEX_DEFAULT_I64 1
+#endif
+/*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return
int32 */
+inline DataType DefaultIndexType() {
+#if TVM_INDEX_DEFAULT_I64
+ return DataType::Int(64);
+#else
+ return DataType::Int(32);
+#endif
+}
+
// forward declare Stmt
class Stmt;
@@ -135,7 +147,7 @@ class BufferNode : public Object {
/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
- return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
+ return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType();
}
/*! \brief Determine the offset in the buffer of the given index.
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index cab18797c6..0f973db290 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore
+def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
+ """Annotate Op Pattern Kind for TIR functions
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.AnnotateTIROpPattern() # type: ignore
+
+
+def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
+ """This pass groups bindings in a dataflow block of Relax functions and
generates a new grouped
+ Relax function for each group, according to the fusion algorithm described
in the pass
+ implementation. By grouping bindings into new Relax functions, we
substitute the bindings in
+ the function being manipulated into function calls to the new grouped
function.
+
+ A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each
grouped function.
+
+ Parameters
+ ----------
+ fuse_opt_level : int
+ The level of fuse optimization. -1 indicates that the level will be
+ inferred from pass context.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for operator fusion.
+ """
+ return _ffi_api.FuseOps(fuse_opt_level) # type: ignore
+
+
+def FuseTIR() -> tvm.ir.transform.Pass:
+ """Fuse primitive relax function into a larger TIR function if possible
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass for tir fusion.
+ """
+ return _ffi_api.FuseTIR() # type: ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/transform/annotate_tir_op_pattern.cc
b/src/relax/transform/annotate_tir_op_pattern.cc
new file mode 100644
index 0000000000..b1c1ed29af
--- /dev/null
+++ b/src/relax/transform/annotate_tir_op_pattern.cc
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/annotate_tir_op_pattern.cc
+ * \brief Annotate Op Pattern for TIR functions. It is a pass that works on TIR
PrimFuncs,
+ * but they are needed for relax fusion. So we put them in the relax
namespace.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace relax {
+
+tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
+ if (f->HasNonzeroAttr("op_pattern")) {
+ return f;
+ } else {
+ relay::OpPatternKind kind = AnalyzeOpPatternKind(f);
+ return WithAttr(std::move(f), "op_pattern",
Integer(static_cast<int>(kind)));
+ }
+}
+
+namespace transform {
+
+Pass AnnotateTIROpPattern() {
+ auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
+ return AnnotateOpPattern(std::move(f));
+ };
+ return tir::transform::CreatePrimFuncPass(pass_func, 0,
"AnnotateTIROpPattern", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
new file mode 100644
index 0000000000..f3559b72da
--- /dev/null
+++ b/src/relax/transform/fuse_ops.cc
@@ -0,0 +1,909 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/fuse_ops.cc
+ * \brief This file contains a pass which groups bindings in a dataflow block
of Relax
+ * functions and generates a new grouped Relax function for each group,
according to the fusion
+ * algorithm described below. By grouping bindings into new Relax functions,
we substitute the
+ * bindings in the function being manipulated into function calls to the new
grouped function.
+ *
+ * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each
grouped function.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/function.h>
+
+#include <optional>
+
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+
+namespace tvm {
+namespace relax {
+
+/*
+ Note on Fusing algorithm:
+
+ The main challenge of general fusor is to handle possible diamond shape
branches,
+ in the following graph, conv2d can be fused to elemwise add.
+
+ conv2d
+ / | \
+ / | \
+ op op op
+ \ | /
+ \ | /
+ elemwise add
+ |
+
+ However, at the point of conv2d we do not necessarily know that all the
future paths
+ will merge at the elemwise add. The fusion algorithm applies post-dominator
analysis.
+
+ The immediate post-dominator of a node is defined as the closest node where all
the future paths go
+ into. In the above case, the elemwise add is the post-dominator of conv2d.
The general algorithm
+ is as follows:
+
+ - Construct a DAG of dataflow graph for dominator analysis
+ - Construct a post-dominator tree which gives immediate post dominator of
each node.
+ - Run fusion algorithm with the given post-dominator information.
+
+ Note that, because we run analysis on a DAG, we use a single pass
post-dominator
+ tree construction algorithm via LCA, which is simpler than the full version
that handles cycles.
+
+ The fusion algorithm traverses from each node and checks if it can be fused
to its
+ immediate post dominator. It has to check the following things:
+
+ - CheckPath: check all the path between a node and its immediate
post-dominator
+ satisfies the fuse condition.
+ - Note that these intermediate nodes can already be fused with other nodes,
the algorithm
+ will still run correctly.
+ - CommitFuse: mark all the nodes between source and post-dominator as the
same group.
+ - We use a Union-Find data structure to manage the groups.
+*/
+
+using relay::GraphPartitioner;
+using relay::IndexedForwardGraph;
+using relay::OpPatternKind;
+using support::LinkNode;
+
+constexpr uint32_t kMaxFusedOps = 256;
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer);
+
+class GraphCreator : public ExprVisitor {
+ public:
+ /*!
+ * \brief Create a IndexedForwardGraph according to the input module. The
graph will be used for
+ * graph partition and operator fusion.
+ * \param mod The module which the creation accords to
+ * \param arena The allocator of all the internal node objects
+ * \return The created IndexedForwardGraph
+ */
+ static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
+ // Since cross-function call is not supported yet, FuseOps only serves the
entry function, whose
+ // name is "main".
+ auto relax_func = Downcast<Function>(mod->Lookup("main"));
+ GraphCreator creator(mod, arena);
+ creator(relax_func);
+
+ // The algorithm of the graph creator ensures that each created node will
be added to the
+ // post-dfs order and will be set its op pattern. Thus we check whether
all these containers
+ // have the same size.
+ size_t n_nodes = creator.graph_.node_map.size();
+ ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size());
+ ICHECK_EQ(n_nodes, creator.initialized_nodes_.size());
+
+ return creator.graph_;
+ }
+
+ private:
+ explicit GraphCreator(IRModule mod, support::Arena* arena)
+ : mod_(std::move(mod)), arena_(arena) {}
+
+ void VisitExpr_(const FunctionNode* func) final {
+ for (const Var& param : func->params) {
+ IndexedForwardGraph::Node* param_node = CreateNode(param.get());
+ // The parameter is passed in from the outside, and thus it's marked as
an external reference,
+ // and its pattern is `kOpaque`.
+ MarkAsExternRef(param_node);
+ SetNodePattern(param_node, OpPatternKind::kOpaque);
+ AddToPostDFSOrder(param_node, param.get());
+ }
+ ExprVisitor::VisitExpr_(func);
+ }
+
+ void VisitBindingBlock(const BindingBlock& block) final {
+ if (const auto* df_block = block.as<DataflowBlockNode>()) {
+ VisitBindingBlock_(df_block);
+ }
+ // We skip ordinary binding blocks since they might be impure (with side
effect or control flow)
+ }
+
+ // TODO(tvm-team): how to deal with MatchCast binding here
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
+
+ // If the variable is not a dataflow variable, it must be the output
variable of this dataflow
+ // block
+ if (!binding->var->IsInstance<DataflowVarNode>()) {
+ this->MarkAsExternRef(node);
+ }
+ if (const auto* call = binding->value.as<CallNode>()) {
+ // Case 1. The expression is a CallNode
+ VisitCall(call, node);
+ } else if (const auto* tuple_get_item =
binding->value.as<TupleGetItemNode>()) {
+ // Case 2. The expression is a TupleGetItemNode
+ VisitTupleGetItem(tuple_get_item, node);
+ } else {
+ VisitUnsupportedNode(binding->value, node);
+ // Case 3. The type of the expression is not fusion-supported.
+ // In this case, we skip adding edges, adding an empty node into graph.
+ }
+ AddToPostDFSOrder(node, binding->var.get());
+ }
+
+ /********** Non-Leaf Expression Nodes **********/
+
+ void VisitCall(const CallNode* call, IndexedForwardGraph::Node*
binding_var_node) {
+ ICHECK_NOTNULL(binding_var_node);
+
+ static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ OpPatternKind pattern = OpPatternKind::kOpaque;
+ Array<Expr> args = call->args;
+
+ // - If the op being called is a TIR PrimFunc, we get the function op
pattern directly from the
+ // function attribute and visit the arguments one by one.
+ // - Otherwise, the pattern of the current binding variable node is set to
`kOpaque`, and we
+ // recurse into the call expression.
+ const auto* op = call->op.as<OpNode>();
+ if (op == call_tir_op_.get()) {
+ const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+ tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
+
+ // Override args for call_tir
+ args = Downcast<Tuple>(call->args[1])->fields;
+
+ // TODO(tvm-team): handle the shape argument (args[3])
+ Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+ if (opt_pattern.defined()) {
+ pattern =
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+ } else {
+ pattern = OpPatternKind::kOpaque;
+ }
+ }
+ // The pattern of the current binding variable node is set to the pattern
of this operator.
+ SetNodePattern(binding_var_node, pattern);
+ // Visit all call args
+ for (const Expr& arg : args) {
+ ICHECK(IsLeaf(arg));
+ VisitLeaf(arg, binding_var_node, pattern);
+ }
+ }
+
+ void VisitTupleGetItem(const TupleGetItemNode* tuple_item,
+ IndexedForwardGraph::Node* binding_var_node) {
+ ICHECK_NOTNULL(binding_var_node);
+
+ SetNodePattern(binding_var_node, OpPatternKind::kInjective);
+ VisitLeaf(tuple_item->tuple, binding_var_node, OpPatternKind::kInjective);
+ }
+
+ void VisitUnsupportedNode(const Expr& expr, IndexedForwardGraph::Node*
binding_var_node) {
+ ICHECK_NOTNULL(binding_var_node);
+ SetNodePattern(binding_var_node, OpPatternKind::kOpaque);
+
+ auto visit_leaves = [this, &binding_var_node](const Expr& e) {
+ if (e->IsInstance<VarNode>() || e->IsInstance<ConstantNode>()) {
+ VisitLeaf(e, binding_var_node, OpPatternKind::kOpaque);
+ }
+ };
+ PostOrderVisit(expr, visit_leaves);
+ }
+
+ /********** Leaf Expression Nodes **********/
+
+ void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node*
binding_var_node,
+ const OpPatternKind& pattern) {
+ ICHECK_NOTNULL(binding_var_node);
+
+ // Recursive visit if it's Tuple
+ if (const auto* tuple = leaf_expr.as<TupleNode>()) {
+ for (const Expr& expr : tuple->fields) {
+ VisitLeaf(expr, binding_var_node, pattern);
+ }
+ return;
+ }
+
+ auto it = graph_.node_map.find(leaf_expr.get());
+ IndexedForwardGraph::Node* leaf_node = nullptr;
+ if (it != graph_.node_map.end()) {
+ leaf_node = it->second;
+ } else if (leaf_expr->IsInstance<ConstantNode>()) {
+ leaf_node = CreateNode(leaf_expr.get());
+ // Since we never fuse constants, the pattern of the constant is set to
`kOpaque`.
+ SetNodePattern(leaf_node, OpPatternKind::kOpaque);
+ AddToPostDFSOrder(leaf_node, leaf_expr.get());
+ } else {
+ LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got:
" << leaf_expr
+ << " used before definition.";
+ }
+ AddEdge(leaf_node, binding_var_node, pattern);
+ }
+
+ /********** Helper Functions **********/
+
+ /*!
+ * \brief Check whether the expression is a leaf expression
+ * \param expr The expression to be checked
+ * \return Whether the expression is a leaf expression
+ * \note In order to avoid too much refactor, this method is a simple
copy-paste of the is-leaf
+ * check in "block_builder.cc". And it should be refactored in the future.
+ * \sa src/relax/ir/block_builder.cc
+ */
+ static bool IsLeaf(const Expr& expr) {
+ // NOTE: Tuples are treated as leaf nodes for ergonomics
+ return expr.as<VarNode>() || expr.as<GlobalVarNode>() ||
expr.as<ConstantNode>() ||
+ expr.as<ShapeExprNode>() || expr.as<ExternFuncNode>() ||
expr.as<OpNode>() ||
+ expr.as<TupleNode>();
+ }
+
+ /*!
+ * \brief Create a graph node corresponding to the input key
+ * \param key The object which is used to create the graph node
+ * \return The created graph node
+ * \note The node corresponding to each key is supposed to be created for
only once
+ */
+ IndexedForwardGraph::Node* CreateNode(const Object* key) {
+ ICHECK(graph_.node_map.find(key) == graph_.node_map.end())
+ << "The node corresponding to the input key is not supposed to be
created before";
+ auto* node = arena_->make<IndexedForwardGraph::Node>();
+ graph_.node_map[key] = node;
+ return node;
+ }
+
+ /*!
+ * \brief Append the input node to the post-dfs order of the graph
+ * \param node The node to be appended
+ * \param key The key corresponding to the node
+ * \note Each node is supposed to be appended to the post-dfs order for only
once
+ */
+ void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) {
+ auto it = graph_.node_map.find(key);
+ ICHECK(it != graph_.node_map.end() && it->second == node)
+ << "The node must have been created before adding to the post-dfs
order";
+
+ // We only set the reference of the node when adding it to the post-dfs
order. Thus, if the
+ // reference of a node is already set, it must have been appended to the
post-dfs order.
+ ICHECK(node->ref == nullptr)
+ << "The node is not supposed to be added into the post-dfs order
before";
+
+ node->ref = key;
+ node->index = graph_.post_dfs_order.size();
+ graph_.post_dfs_order.push_back(node);
+ }
+
+ /*!
+ * \brief Add an edge from the input start to the input end in the graph,
with specific pattern
+ * \param start The start of the edge
+ * \param end The end of the edge
+ * \param pattern The pattern of this edge
+ */
+ void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node*
end,
+ OpPatternKind pattern) {
+ auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
+ link->value.node = end;
+ link->value.pattern = pattern;
+ start->outputs.Push(link);
+ }
+
+ /*!
+ * \brief Mark a given node as "external reference", which means the node
cannot be fused as an
+ * intermediate node
+ * \param node The graph node to be marked
+ */
+ void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref =
true; }
+
+ /*!
+ * \brief Set the pattern of the input node
+ * \param node The graph node to be set
+ * \param pattern The pattern of the node
+ */
+ void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) {
+ ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end())
+ << "The input node is supposed to be set pattern for only once";
+ initialized_nodes_.insert(node);
+ node->pattern = pattern;
+ }
+
+ private:
+ /*! \brief The IRModule from which the indexed forward graph is created */
+ IRModule mod_;
+ /*! \brief The allocator of all the internal node objects */
+ support::Arena* arena_;
+ /*! \brief The created indexed forward graph */
+ IndexedForwardGraph graph_;
+ /*! \brief The graph nodes whose patterns are set */
+ std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
+};
+
+/*!
+ * \brief The ExprMutator used to create a new grouped function
+ * \details The workflow of this ExprMutator is:
+ * - The bindings in the function will be added by OperatorFusor via
`AppendBinding(...)`.
+ * - When adding a new binding through `AppendBinding(...)`, we check whether
the variables and
+ * constants used by the binding are defined by some previous added binding.
And for the undefined
+ * variables and constants, we add them to the argument list and created new
variables as the
+ * corresponding parameters.
+ * - When `CreateFunction()` is called, we go through each binding and update
the binding with the
+ * new parameters. After that we wrap all bindings with a DataflowBlock and a
Function.
+ */
+class FunctionCreator : public ExprMutator {
+ public:
+ explicit FunctionCreator(bool lift_constant) : lift_constant_(lift_constant)
{}
+ /*!
+ * \brief Append a new binding to this function and possibly create new
parameters for the
+ * function accordingly
+ * \param binding The binding to be appended
+ * \note Allowed bindings are:
+ * - VarBinding with value being a call node calling `relax.call_tir`.
+ * - VarBinding with value being a tuple-get-item node.
+ * // TODO(tvm-team): handle match shape
+ */
+ void AppendBinding(const Binding& binding) {
+ ICHECK(!function_.defined())
+ << "The `function_` is supposed to be uncreated when adding bindings";
+
+ if (const auto* var_binding = binding.as<VarBindingNode>()) {
+ if (const auto* call = var_binding->value.as<CallNode>()) {
+ if (call->op == Op::Get("relax.call_tir")) {
+ // Update the name of the function.
+ name_hint_ = name_hint_ + "_" +
Downcast<GlobalVar>(call->args[0])->name_hint;
+
+ const Tuple& args = Downcast<Tuple>(call->args[1]);
+ for (const Expr& arg : args->fields) {
+ CheckDefAndUpdateParam(arg);
+ }
+ // TODO(tvm-team): handle shape expr
+ } else {
+ if (call->op->IsInstance<OpNode>()) {
+ name_hint_ = name_hint_ + "_" + Downcast<Op>(call->op)->name;
+ } else if (call->op->IsInstance<GlobalVarNode>()) {
+ std::string gvar_name = Downcast<GlobalVar>(call->op)->name_hint;
+ if (auto pos = gvar_name.find("fused_"); pos == 0) {
+ name_hint_ = name_hint_ + "_" +
gvar_name.substr(std::string("fused_").size());
+ } else {
+ name_hint_ = name_hint_ + "_" + gvar_name;
+ }
+ }
+
+ for (const Expr& arg : call->args) {
+ CheckDefAndUpdateParam(arg);
+ }
+ }
+ } else {
+ const auto* tuple_item = var_binding->value.as<TupleGetItemNode>();
+ ICHECK(tuple_item != nullptr);
+ CheckDefAndUpdateParam(tuple_item->tuple);
+ }
+
+ // Mark the binding variable as defined.
+ defined_vars_.insert(var_binding->var.get());
+ // Set var as output true if the binding is not a dataflow variable
+ if (!var_binding->var->IsInstance<DataflowVarNode>()) {
+ AppendOutput(var_binding->var);
+ }
+ } else {
+ // TODO(tvm-team): handle match_cast
+ }
+ bindings_.push_back(binding);
+ }
+
+ /*! \brief Set a var defined in the group as output. */
+ size_t AppendOutput(const Var& var) {
+ ICHECK(defined_vars_.count(var.get()));
+ auto output_idx = GetOutputIndex(var);
+ if (output_idx) {
+ return *output_idx;
+ }
+ output_vars_.push_back(var.get());
+ return output_vars_.size() - 1;
+ }
+
+ /*!
+ * \brief Create the grouped function according to the collected
bindings and parameters
+ * \param composite_name The name to identify the pattern this function is
created from, if any.
+ * It will become the value of the kComposite attribute of the created
function.
+ * \note The created function won't be returned immediately. It's stored in
the `function_` field.
+ */
+ void CreateFunction(Map<String, ObjectRef> group_attrs) {
+ // Step 1. Start constructing a new dataflow block.
+ builder_->BeginDataflowBlock();
+
+ // Step 2. Visit each binding and collect outputs one by one.
+ Array<Expr> outputs(output_vars_.size(), Expr());
+ for (const Binding& binding : bindings_) {
+ if (auto output_idx = GetOutputIndex(binding->var)) {
+ // Case 1. It is an output binding
+ // We only allow VarBinding as output.
+ const auto* var_binding = binding.as<VarBindingNode>();
+ ICHECK_NOTNULL(var_binding);
+ Var output_var = builder_->EmitOutput(VisitExpr(var_binding->value));
+ var_remap_[var_binding->var->vid] = output_var;
+ outputs.Set(*output_idx, output_var);
+ } else {
+ // Case 2. It is an internal binding, add it to the binding list.
+ VisitBinding(binding);
+ }
+ }
+
+ // Step 3. Finish constructing the new block.
+ BindingBlock new_block = builder_->EndBlock();
+ ICHECK(!outputs.empty()) << "At least one output is required.";
+ Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+ body = builder_->Normalize(body);
+ body = builder_->Normalize(SeqExpr({new_block}, body));
+ group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+ function_ = Function(/*params=*/params_, //
+ /*body=*/body, //
+ /*ret_struct_info=*/NullOpt, //
+ /*attrs=*/DictAttrs(group_attrs));
+ }
+
+ /*! \brief The original bindings of the function */
+ Array<Binding> bindings_;
+ /*! \brief The parameters of the function */
+ Array<Var> params_;
+ /*! \brief The arguments to call the function on the caller side */
+ Array<Expr> arguments_;
+ /*! \brief The name for the fused function */
+ String name_hint_ = "fused";
+ /*! \brief The constructed Relax function */
+ Function function_{nullptr};
+
+ private:
+ std::optional<size_t> GetOutputIndex(Var v) {
+ auto it = std::find(output_vars_.begin(), output_vars_.end(), v.get());
+ if (it != output_vars_.end()) {
+ return std::distance(output_vars_.begin(), it);
+ }
+ return std::nullopt;
+ }
+
+ /*!
+ * \brief Check whether the input expression is defined within this
function. If not, create a new
+ * parameter for the expression.
+ * \param expr The expression to be checked
+ */
+ void CheckDefAndUpdateParam(const Expr& expr) {
+ // If the expression has already served as an argument, no need to create
another one for it.
+ if (std::find(arguments_.begin(), arguments_.end(), expr) !=
arguments_.end()) {
+ return;
+ }
+
+ // If the expression is not a variable or is an undefined variable, it
should be populated as a
+ // parameter of the relax function.
+ const auto* var = expr.as<VarNode>();
+ if ((var == nullptr || defined_vars_.count(var) == 0) &&
+ (lift_constant_ || !expr->IsInstance<ConstantNode>())) {
+ String name{nullptr};
+ if (var != nullptr) {
+ name = var->name_hint();
+ } else {
+ name = String("param_" + std::to_string(n_param_for_const_++));
+ }
+
+ Var param(std::move(name), GetStructInfo(expr));
+ arguments_.push_back(expr);
+ params_.push_back(param);
+ }
+ }
+
+ Expr VisitExpr(const Expr& expr) final {
+ // If the expression serves as an argument, return its corresponding
parameter.
+ auto it = std::find(arguments_.begin(), arguments_.end(), expr);
+ if (it != arguments_.end()) {
+ return params_[it - arguments_.begin()];
+ }
+ // Otherwise, recurse into this expression.
+ return ExprMutator::VisitExpr(expr);
+ }
+
+ private:
+ /*! \brief The variables defined in this function */
+ std::unordered_set<const VarNode*> defined_vars_;
+ /*! \brief The number of parameters reserved for constants */
+ int n_param_for_const_ = 0;
+ /*! \brief The output vars */
+ std::vector<const VarNode*> output_vars_;
+ /*! \brief Whether or not to lift bound constants to parameters */
+ bool lift_constant_;
+};
+
+/*!
+ * \brief The ExprMutator used to fuse the operators in Relax functions
+ * \details Given the partition results on the indexed-forward graph, for each
group whose size is
+ * larger than one, we create a new grouped function for it, containing all
bindings in that group.
+ * And we substitute the bindings in a group with a single function call to
the newly created
+ * grouped function. The workflow of this ExprMutator is: for each dataflow
block,
+ * - we go through the bindings one by one. For each binding, if it is in a
group whose size is
+ * larger than one, we add the binding to the function of the group it is in
and update the
+ * parameters and arguments of that function;
+ * - then we finalize all the grouped functions by updating their bindings
using BlockBuilder;
+ * - lastly, we go through the bindings again and substitute the bindings in
a group with a single
+ * call to the corresponding grouped function.
+ *
+ * After transforming a Relax function, we update the function in the
IRModule. Besides, we add all
+ * newly created grouped function to the IRModule.
+ */
+class OperatorFusor : public ExprMutator {
+ public:
+ using Group = GraphPartitioner::Group;
+ using GroupMap = std::unordered_map<const Object*, Group*>;
+
+ OperatorFusor(IRModule mod, const GroupMap& obj2group, bool lift_constants =
true)
+ : ExprMutator(mod),
+ mod_(std::move(mod)),
+ obj2group_(obj2group),
+ lift_constants_(lift_constants) {}
+
+ /*!
+ * \brief Construct a new operator fusor. Given the indexed-forward graph
and the graph partition
+ * result on that graph, the constructor creates a mapping from each leaf
AST object
+ * (e.g. parameters, variables, constants) to the group of the node
corresponding to the object
+ * in the graph.
+ * \param mod The IRModule to be transformed
+ * \param graph The indexed-forward graph of the input IRModule
+ * \param groups The grouped result of the group partition on the input
indexed-forward graph.
+ */
+ OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const
std::vector<Group*>& groups,
+ bool lift_constant = true)
+ : OperatorFusor(mod, CreateGroupMap(graph, groups), lift_constant) {}
+
+ /*!
+ * \brief The main transformation on the IRModule
+ * \return The new IRModule after transformation
+ */
+ IRModule Transform() {
+ for (const auto& [gv, func] : mod_->functions) {
+ // Only visit Relax function without attr kPrimitive.
+ if (func->IsInstance<relax::FunctionNode>() &&
!func->HasNonzeroAttr(attr::kPrimitive)) {
+ auto updated_func = Downcast<Function>(VisitExpr(func));
+ builder_->UpdateFunction(gv, updated_func);
+ }
+ }
+ return builder_->GetContextIRModule();
+ }
+
+ private:
+ static GroupMap CreateGroupMap(const IndexedForwardGraph& graph,
+ const std::vector<Group*>& groups) {
+ GroupMap obj2group;
+ for (int nid = 0; nid < static_cast<int>(graph.post_dfs_order.size());
++nid) {
+ Group* group_root = groups[nid]->FindRoot();
+ ICHECK(group_root != nullptr);
+ ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
+ obj2group[graph.post_dfs_order[nid]->ref] = group_root;
+ }
+ return obj2group;
+ }
+
+ bool IsTupleOutput(Function f) {
+ auto sinfo = GetStructInfo(f).as<FuncStructInfoNode>();
+ ICHECK(sinfo);
+ return sinfo->ret->IsInstance<TupleStructInfoNode>();
+ }
+
+ BindingBlock VisitBindingBlock(const BindingBlock& block) final {
+ if (const auto* df_block = block.as<DataflowBlockNode>()) {
+ return VisitBindingBlock_(df_block);
+ }
+ // We skip ordinary binding blocks since they might be impure (with side
effect or control flow)
+ return block;
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+ group2func_.clear();
+
+ // Step 1. Collect the bindings for each grouped function.
+ CollectFuncBindings(block->bindings);
+
+ // Step 2. Collect all group's boundary (i.e. the output vars for each
group)
+ CollectFuncBoundary(block->bindings);
+
+ // Step 3. Create the grouped function for each group.
+ for (auto& [g, creator] : group2func_) {
+ creator.CreateFunction(g->attrs);
+ }
+
+ // Step 4. Start generating the new binding block.
+ // - For groups with single binding, we directly recurse into the binding
and emit the new one.
+ // - For groups with multiple bindings, we emit the call to the grouped
function only when
+ // visiting the last binding of the group, because only by doing this we
don't break the
+ // dependencies among the bindings of different groups. And therefore, we
will skip all but the
+ // last binding of the group.
+ builder_->BeginDataflowBlock();
+
+ // For each group, record which variables need to be remapped to the
output of TupleGetItem.
+ // Only relevant when the output of the grouped function is a tuple.
+ std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;
+
+ // A grouped function which returns a tuple requires attaching
TupleGetItem to each element and
+ remapping variables in earlier bindings appropriately. Thus, a binding
whose value depends on
+ // some elements of a tuple from other group's function must be emitted
after a call to the
+ // tuple-producing function is emitted and remapping is done.
+ // To guarantee this, we process bindings in the order of the topological
sort of the group
+ // dependency relations.
+ for (const auto& binding : TopoSortByGroupDep(block->bindings)) {
+ // Case 1. If the binding is the only binding in its group, recurse into
it and emit the
+ // transformed binding as usual.
+ Group* group = GetGroupFromBinding(binding);
+ if (group->num_nodes == 1 && group->attrs.empty()) {
+ VisitBinding(binding);
+ continue;
+ }
+
+ const auto& it_creator = group2func_.find(group);
+ ICHECK(it_creator != group2func_.end());
+ const FunctionCreator& func_info = it_creator->second;
+
+ // If this binding belongs to a group whose output is a tuple, the
original bound variable
+ // needs to be remapped to the output of TupleGetItem after the
corresponding tuple is
+ // emitted.
+ if (IsTupleOutput(func_info.function_) &&
tuple_get_indices_.count(binding->var.get())) {
+ pending_tuple_get[group].push_back(binding->var);
+ }
+
+ // Case 2. If the binding is not the last binding of the group, we skip
it.
+ if (!func_info.bindings_.back().same_as(binding)) {
+ continue;
+ }
+
+ // Case 3. The binding is the last binding of the group.
+ const auto* var_binding = binding.as<VarBindingNode>();
+ ICHECK(var_binding != nullptr) << "The last binding of a group whose
size is larger than 1 "
+ "is supposed to be a variable binding";
+
+ // Step a. Add the grouped function to the IRModule
+ GlobalVar gv = builder_->AddFunction(func_info.function_,
func_info.name_hint_);
+
+ // Step b. Create the call to the deduplicated function, and then emit
the call.
+ // - If this binding is an output binding, emit an output variable.
+ // - Otherwise, emit a dataflow variable.
+ Var new_var;
+ Call call_to_emit = Call(gv, UpdateArgs(func_info.arguments_));
+
+ if (var_binding->var->IsInstance<DataflowVarNode>()) {
+ new_var = builder_->Emit(call_to_emit);
+ } else {
+ new_var = builder_->EmitOutput(call_to_emit);
+ }
+
+ // Step c. Update the mapping used for the remapping of the binding
variables.
+ if (IsTupleOutput(func_info.function_)) {
+ // If the output is a tuple, attach TupleGetItem to all tuple
elements, and
+ // remap variables appropriately.
+ // The variables that need to be remapped and the corresponding tuple
indices are
+ // available in pending_tuple_get and tuple_get_indices_ respectively.
+ for (const auto& var : pending_tuple_get[group]) {
+ auto tuple_get = TupleGetItem(new_var,
tuple_get_indices_[var.get()]);
+ var_remap_[var->vid] = builder_->Emit(tuple_get);
+ }
+ } else {
+ var_remap_[var_binding->var->vid] = new_var;
+ }
+ }
+ // Step 5. Finish the binding block generation.
+ return builder_->EndBlock();
+ }
+
+ /*!
+ * \brief Collect the bindings for each grouped function and update the
information of the grouped
+ * function
+ * \param bindings The bindings to be collected
+ * \note The function update is done by `AppendBinding(...)`
+ */
+ void CollectFuncBindings(const Array<Binding>& bindings) {
+ for (const Binding& binding : bindings) {
+ // If the binding is the only binding in its group, there is no need to
create a new function.
+ Group* group = GetGroupFromBinding(binding);
+ if (group->num_nodes == 1 && group->attrs.empty()) {
+ continue;
+ }
+ // Add the binding to the grouped function it's in, and update the
function information
+ // accordingly.
+ if (!group2func_.count(group)) {
+ group2func_.emplace(group, lift_constants_);
+ }
+ group2func_.find(group)->second.AppendBinding(binding);
+ }
+ }
+
+ void CollectFuncBoundary(const Array<Binding>& bindings) {
+ for (const Binding& binding : bindings) {
+ // Step 1. Get current binding's group
+ Group* cur_group = GetGroupFromBinding(binding);
+
+ // Step 2. Collect all used vars in the binding value and update boundary.
+ // - If the var's group is same as the binding's, the var is defined in
the same group
+ // - If the var's group is different with the binding's, the var must be
the output from
+ // another group. Mark it to be the group output.
+ auto update_boundary = [this, binding, &cur_group](const Expr& e) {
+ if (e->IsInstance<VarNode>()) {
+ const Var& used_var = Downcast<Var>(e);
+ Group* producer_group = GetGroupFromVar(used_var);
+ // Only check those group defined before.
+ // Skip the vars from input or groups with single binding.
+ if (producer_group != cur_group) {
+ ICHECK(!group_deps_[producer_group].count(cur_group))
+ << "A cyclic dependency detected between the groups " <<
binding->var->name_hint()
+ << " and " << used_var->name_hint() << " are in.";
+ group_deps_[cur_group].insert(producer_group);
+ }
+
+ if (auto producer = group2func_.find(producer_group);
+ producer_group != cur_group && producer != group2func_.end()) {
+ auto output_index = producer->second.AppendOutput(used_var);
+ tuple_get_indices_[used_var.get()] = output_index;
+ }
+ }
+ };
+
+ if (const auto* var_binding = binding.as<VarBindingNode>()) {
+ PostOrderVisit(var_binding->value, update_boundary);
+ } else {
+ const auto* match_cast = binding.as<MatchCastNode>();
+ ICHECK_NOTNULL(match_cast);
+ PostOrderVisit(match_cast->value, update_boundary);
+ }
+ }
+ }
+
+ /*!
+ * \brief Get the group which the input binding is in
+ * \param binding The binding to be queried
+ * \return The pointer to the group which the input binding is in
+ */
+ Group* GetGroupFromBinding(const Binding& binding) {
+ Var var = binding->var;
+ return GetGroupFromVar(var);
+ }
+
+ /*!
+ * \brief Get the group which the input var is in
+ * \param Var The var to be queried
+ * \return The pointer to the group which the input var is in
+ */
+ Group* GetGroupFromVar(const Var& var) {
+ const auto& it_group = obj2group_.find(var.get());
+ ICHECK(it_group != obj2group_.end());
+ Group* group = it_group->second;
+ return group->FindRoot();
+ }
+
+ /*!
+ * \brief Update the pre-stored arguments according to the variable
remapping of the fusor, by
+ * recursing into each argument
+ * \param args The arguments to be updated
+ * \return The updated arguments
+ */
+ Array<Expr> UpdateArgs(const Array<Expr>& args) {
+ Array<Expr> new_args;
+ new_args.reserve(args.size());
+ for (const Expr& arg : args) {
+ new_args.push_back(VisitExpr(arg));
+ }
+ return new_args;
+ }
+
+ private:
+ // Topologically sort bindings according to the group dependency relations.
+ Array<Binding> TopoSortByGroupDep(const Array<Binding>& bindings) {
+ std::unordered_map<Group*, std::vector<Binding>> bindings_per_group;
+ // The order to visit groups should respect the original order of bindings
as much as possible.
+ std::vector<Group*> group_order;
+ for (const auto& binding : bindings) {
+ auto g = GetGroupFromBinding(binding);
+ group_order.push_back(g); // Duplication does not matter since each
group is visited once.
+ bindings_per_group[g].push_back(binding);
+ }
+
+ std::unordered_set<Group*> visited;
+
+ std::function<void(Group*, std::function<void(Group*)>)> dfs_visit;
+ dfs_visit = [this, &visited, &dfs_visit](Group* g, auto leaf_fun) {
+ if (!visited.count(g)) {
+ visited.insert(g);
+ for (auto dep : group_deps_[g]) {
+ dfs_visit(dep, leaf_fun);
+ }
+ leaf_fun(g);
+ }
+ };
+
+ Array<Binding> sorted;
+
+ for (auto g : group_order) {
+ dfs_visit(g, [&sorted, &bindings_per_group](Group* leaf) {
+ for (const auto& binding : bindings_per_group[leaf]) {
+ sorted.push_back(binding);
+ }
+ });
+ }
+
+ return sorted;
+ }
+
+ /*! \brief The IRModule. */
+ IRModule mod_;
+ /*! \brief Internal arena. */
+ support::Arena arena_;
+ /*! \brief The group assignment map. */
+ GroupMap obj2group_;
+ /*! \brief Internal function information map. */
+ std::unordered_map<Group*, FunctionCreator> group2func_;
+ /*! \brief Record the index for TupleGetItem if the variable needs to be
remapped to an output
+ * tuple element after fusion. */
+ std::unordered_map<const VarNode*, int> tuple_get_indices_;
+ /*! \brief A map from a group to its dependent groups, used to detect cyclic
dependencies. */
+ std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
+ /*! \brief Whether or not to lift bound constants to parameters of the
grouped function. */
+ bool lift_constants_{true};
+};
+
+IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) {
+ support::Arena arena;
+
+ // Step 1. Create the indexed-forward graph according to the input IRModule.
+ IndexedForwardGraph graph = GraphCreator::Create(mod, &arena);
+
+ // Step 2. Partition the graph by applying the fusion algorithm.
+ std::vector<GraphPartitioner::Group*> groups =
+ GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph);
+
+ // Step 3. Transform the IRModule by fusing the operators in accordance with
the graph partition
+ // results.
+ return OperatorFusor(mod, graph, groups, /*lift_constants*/
true).Transform();
+}
+
+namespace transform {
+
+Pass FuseOps(int fuse_opt_level) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
+ [=](IRModule m, PassContext pc) {
+ int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
+ auto max_fuse_depth = pc->GetConfig("relax.FuseOps.max_depth",
Integer(kMaxFusedOps));
+ return relax::FuseOps(m, opt_level, max_fuse_depth.value().IntValue());
+ };
+ return CreateModulePass(/*pass_function=*/pass_func, //
+ /*opt_level=*/0, //
+ /*pass_name=*/"FuseOps", //
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
new file mode 100644
index 0000000000..fa5c296d27
--- /dev/null
+++ b/src/relax/transform/fuse_tir.cc
@@ -0,0 +1,728 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace tir {
+
+// TODO(Siyuan): move it to somewhere under tir folder
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in
statements or expressions.
+ */
+class FuseTIRBufferSubstitor : private StmtExprMutator {
+ public:
+ static Stmt Substitute(const Map<Buffer, Buffer>& buffer_map, Stmt stmt) {
+ return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt));
+ }
+
+ private:
+ explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map) {
+ for (const auto& kv : buffer_map) {
+ const Buffer& src = kv.first;
+ const Buffer& tgt = kv.second;
+ buffer_var_map_[src->data.get()] = tgt;
+ }
+ }
+
+ PrimExpr VisitExpr_(const VarNode* _op) final {
+ auto it = buffer_var_map_.find(_op);
+ if (it != buffer_var_map_.end()) {
+ return it->second->data;
+ } else {
+ return GetRef<PrimExpr>(_op);
+ }
+ }
+
+ PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
+ BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
+ auto it = buffer_var_map_.find(load->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ auto n = make_object<BufferLoadNode>(*load.get());
+ n->buffer = it->second;
+ return BufferLoad(n);
+ } else {
+ return std::move(load);
+ }
+ }
+
+ Stmt VisitStmt_(const BufferStoreNode* _op) final {
+ BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
+ auto it = buffer_var_map_.find(store->buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ auto n = CopyOnWrite(store.get());
+ n->buffer = it->second;
+ return BufferStore(n);
+ } else {
+ return std::move(store);
+ }
+ }
+
+ PrimExpr VisitExpr_(const LoadNode* _op) final {
+ Load load = Downcast<Load>(StmtExprMutator::VisitExpr_(_op));
+ auto it = buffer_var_map_.find(load->buffer_var.get());
+ if (it != buffer_var_map_.end()) {
+ auto n = make_object<LoadNode>(*load.get());
+ n->buffer_var = it->second->data;
+ return Load(n);
+ } else {
+ return std::move(load);
+ }
+ }
+
+ Stmt VisitStmt_(const StoreNode* _op) final {
+ Store store = Downcast<Store>(StmtExprMutator::VisitStmt_(_op));
+ auto it = buffer_var_map_.find(store->buffer_var.get());
+ if (it != buffer_var_map_.end()) {
+ auto n = CopyOnWrite(store.get());
+ n->buffer_var = it->second->data;
+ return Store(n);
+ } else {
+ return std::move(store);
+ }
+ }
+
+ Stmt VisitStmt_(const BlockNode* _op) final {
+ Block block = Downcast<Block>(StmtMutator::VisitStmt_(_op));
+
+ // Define the mutation functions.
+ auto f_mutate_match_buffers = [this](const MatchBufferRegion&
match_buffer) {
+ const Buffer& src_buffer = match_buffer->source->buffer;
+ auto it = buffer_var_map_.find(src_buffer->data.get());
+ if (it != buffer_var_map_.end()) {
+ return MatchBufferRegion(match_buffer->buffer,
+ BufferRegion(it->second,
match_buffer->source->region));
+ } else {
+ return match_buffer;
+ }
+ };
+
+ auto f_mutate_read_write_region = [this](const BufferRegion&
buffer_region) {
+ auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
+ return it == buffer_var_map_.end() ? buffer_region
+ : BufferRegion(it->second,
buffer_region->region);
+ };
+
+ // Step 1. Mutate `match_buffers`.
+ Array<MatchBufferRegion> match_buffers =
+ MutateArray(block->match_buffers, f_mutate_match_buffers);
+ // Step 2. Mutate the read/write region.
+ Array<BufferRegion> reads = MutateArray(block->reads,
f_mutate_read_write_region);
+ Array<BufferRegion> writes = MutateArray(block->writes,
f_mutate_read_write_region);
+
+ reads = UnionAccessRegion(reads);
+ writes = UnionAccessRegion(writes);
+
+ if (reads.same_as(block->reads) && //
+ writes.same_as(block->writes) && //
+ match_buffers.same_as(block->match_buffers)) {
+ return std::move(block);
+ } else {
+ auto n = CopyOnWrite(block.get());
+ n->reads = std::move(reads);
+ n->writes = std::move(writes);
+ n->match_buffers = std::move(match_buffers);
+ return Block(n);
+ }
+ }
+
+ private:
+ /*! \brief Mapping from src buffer.data to tgt buffer. */
+ std::unordered_map<const tir::VarNode*, tir::Buffer> buffer_var_map_;
+ /*! \brief The structural equality checker */
+ StructuralEqual structural_equal_;
+
+ Array<tir::BufferRegion> UnionAccessRegion(const Array<BufferRegion>&
regions) const {
+ // For now we only allow Buffer accesses to the same elements.
+ // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but need to union to
`A[vi, vj]`
+ // However, `A[vi, vj], A[vi, vj + 1]` is not allowed for now.
+ // Note: the order of the returned regions should remain the same as the first
occurrence of the region
+ Array<BufferRegion> ret;
+ std::unordered_map<const BufferNode*, Region> buffer_region_set;
+
+ for (const BufferRegion& region : regions) {
+ auto it = buffer_region_set.find(region->buffer.get());
+ if (it == buffer_region_set.end()) {
+ ret.push_back(region);
+ buffer_region_set[region->buffer.get()] = region->region;
+ } else {
+ ICHECK(structural_equal_(region->region, it->second));
+ }
+ }
+
+ if (ret.size() == regions.size()) {
+ return regions;
+ } else {
+ return ret;
+ }
+ }
+};
+
+/*! \brief A mutator which detect block name duplication and deduplicate the
names. */
+class BlockNameDeduplicator : public tir::StmtMutator {
+ private:
+ Stmt VisitStmt_(const BlockNode* op) final {
+ Block block = Downcast<Block>(tir::StmtMutator::VisitStmt_(op));
+
+ String name = GetUniqueName(block->name_hint);
+
+ if (name == block->name_hint) {
+ return std::move(block);
+ } else {
+ ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
+ n->name_hint = std::move(name);
+ return Stmt(n);
+ }
+ }
+
+ String GetUniqueName(const String& prefix) {
+ String unique_prefix = prefix;
+ auto it = name_count_.find(prefix);
+ while (name_count_.count(unique_prefix)) {
+ unique_prefix = prefix + "_" + std::to_string(++it->second);
+ }
+ name_count_[unique_prefix] = 0;
+ return unique_prefix;
+ }
+
+ // TODO(relax-team): It should detect the number suffix and do renaming
properly
+ // e.g. GetUniqueName("name1") should return "name2" instead of "name10".
+ /*! \brief The count map to make block name unique. */
+ std::unordered_map<String, int> name_count_;
+};
+
+} // namespace tir
+
+namespace relax {
+
+class FusedTIRConstructor : public ExprVisitor {
+ public:
+ /*!
+ * \brief Construct a fused TIR PrimFunc from a relax sub-function
+ * \param mod The IRModule
+ * \param gv The global var of relax subfunction to be fused into one
PrimFunc
+ * \return The fused TIR PrimFunc
+ */
+ static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) {
+ FusedTIRConstructor visitor(mod, gv->name_hint);
+ BaseFunc f = mod->Lookup(gv);
+ CHECK(f->IsInstance<relax::FunctionNode>())
+ << "Expected relax functions, but got: " << f->GetTypeKey();
+ CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive))
+ << "Expected a function with attr `kPrimitive`";
+ visitor(Downcast<relax::Function>(f));
+ return visitor.fused_tir_;
+ }
+
+ private:
+ explicit FusedTIRConstructor(const IRModule& mod, const String& func_name)
+ : mod_(mod), func_name_(func_name) {}
+
+ void VisitExpr_(const FunctionNode* func) final {
+ // Step 1. Create buffers for function params
+ for (const Var& relax_param : func->params) {
+ auto ret = CreateParamsAndBuffers(GetStructInfo(relax_param), //
+ relax_param->name_hint());
+ const Array<tir::Var>& params = ret.first;
+ const Array<tir::Buffer>& buffers = ret.second;
+ ICHECK_EQ(params.size(), buffers.size());
+ for (size_t i = 0; i < params.size(); ++i) {
+ func_info_.buffer_map.Set(params[i], buffers[i]);
+ func_info_.params.push_back(params[i]);
+ }
+ func_info_.expr2buffers.Set(relax_param, buffers);
+ }
+
+ // Step 2. Visit Function body and create intermediate buffers
+ ExprVisitor::VisitExpr_(func);
+
+ // Step 3. Create and remap buffers for function output
+ ICHECK(func->body->IsInstance<SeqExprNode>())
+ << "Function body is expected to be a SeqExpr, but got: " <<
func->body->GetTypeKey();
+ Expr body = Downcast<SeqExpr>(func->body)->body;
+ auto it = func_info_.expr2buffers.find(body);
+ ICHECK(it != func_info_.expr2buffers.end())
+ << "Fail to detect output buffers for function body";
+ const Array<tir::Buffer>& buffers = (*it).second;
+ for (size_t i = 0; i < buffers.size(); ++i) {
+ tir::Var param = tir::Var("p_output" + std::to_string(i),
PrimType(DataType::Handle()));
+ func_info_.buffer_map.Set(param, buffers[i]);
+ func_info_.params.push_back(param);
+ func_info_.output_buffers.insert(buffers[i].get());
+ }
+
+ // Step 4. Create PrimFunc
+ fused_tir_ = ConstructFunc();
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ // Update expr2buffers by visiting values.
+ this->VisitExpr(binding->value);
+ auto it = func_info_.expr2buffers.find(binding->value);
+ if (it != func_info_.expr2buffers.end()) {
+ // assign binding var to the buffers of the value
+ func_info_.expr2buffers.Set(binding->var, (*it).second);
+ } else {
+ LOG(FATAL) << "Unsupported binding value: " << binding->value;
+ }
+ }
+
+ void VisitBinding_(const MatchCastNode* match_cast) final {
+ LOG(FATAL) << "MatchCast is unsupported in primitive functions";
+ }
+
+ void VisitExpr_(const CallNode* call) final {
+ ExprVisitor::VisitExpr_(call);
+ static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ ICHECK(call->op == call_tir_op_)
+ << "Only call_tir is supported in primitive function, but got: " <<
GetRef<Expr>(call);
+
+ // Step 1. Get Global var and PrimFunc
+ GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+ Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
+ ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir
in the module: "
+ << gv;
+ // Step 2. Renew all vars/buffer definitions and blocks to avoid
duplication
+ tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
+
+ // Step 3. Check functions are all schedulable funcs. i.e. the body of
func is root block
+ // TODO(Siyuan): support un-schedulable functions.
+ ICHECK(prim_func->body->IsInstance<tir::BlockRealizeNode>())
+ << "Only schedulable functions (whose body is the root block) can be
fused";
+ const tir::BlockRealize& root_realize =
Downcast<tir::BlockRealize>(prim_func->body);
+ const tir::Block& root_block = root_realize->block;
+
+ // Step 4. Add all the original alloc_buffers and body to the fused
function.
+ func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(),
+ root_block->alloc_buffers.begin(),
+ root_block->alloc_buffers.end());
+ func_info_.bodies.push_back(root_block->body);
+
+ // Step 5. Map input arguments to buffer
+ MapInputBuffer(prim_func, call->args[1]);
+ size_t num_output_buffers = GetCallTIROutputSize(call);
+ AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func,
num_output_buffers);
+ // Update fused func name
+ func_info_.global_name += "_" + gv->name_hint;
+ }
+
+ void VisitExpr_(const TupleGetItemNode* tuple_get_item) final {
+ ExprVisitor::VisitExpr_(tuple_get_item);
+ auto it = func_info_.expr2buffers.find(tuple_get_item->tuple);
+ if (it != func_info_.expr2buffers.end()) {
+ int begin_buf_idx = 0;
+ int end_buf_idx = 0;
+ const TupleType& tuple_type =
Downcast<TupleType>(tuple_get_item->tuple->checked_type());
+ for (int i = 0; i < tuple_get_item->index; ++i) {
+ begin_buf_idx += GetTotalTensorSize(tuple_type->fields[i]);
+ }
+ end_buf_idx = begin_buf_idx +
GetTotalTensorSize(tuple_type->fields[tuple_get_item->index]);
+ func_info_.expr2buffers.Set(
+ GetRef<Expr>(tuple_get_item),
+ {(*it).second.begin() + begin_buf_idx, (*it).second.begin() +
end_buf_idx});
+ }
+ }
+
+ void VisitExpr_(const TupleNode* tuple) final {
+ ExprVisitor::VisitExpr_(tuple);
+ Array<tir::Buffer> buffers;
+ for (const Expr& expr : tuple->fields) {
+ auto it = func_info_.expr2buffers.find(expr);
+ if (it != func_info_.expr2buffers.end()) {
+ buffers.insert(buffers.end(), (*it).second.begin(),
(*it).second.end());
+ }
+ }
+ if (!buffers.empty()) {
+ func_info_.expr2buffers.Set(GetRef<Expr>(tuple), buffers);
+ }
+ }
+
+ void VisitExpr_(const ConstantNode* op) final {
+ LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
+ }
+
+ /********** Helper Functions **********/
+
+ /*!
+ * \brief Pattern match op to a TIR function and look it up.
+ * \return The TIR function, or NullOpt if pattern match fails.
+ */
+ Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
+ // NOTE: as check works for nullptr(returns null)
+ Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
+ if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+ return GetRef<tir::PrimFunc>(pfunc);
+ } else {
+ return NullOpt;
+ }
+ }
+
+ /*!
+ * \brief Get the number of outputs for a call_tir node.
+ * \return The number of outputs.
+ */
+ static size_t GetCallTIROutputSize(const CallNode* call) {
+ static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ ICHECK(call->op.same_as(call_tir_op_));
+ ICHECK_EQ(call->sinfo_args.size(), 1);
+ if (const auto* tuple_sinfo =
call->sinfo_args[0].as<TupleStructInfoNode>()) {
+ return tuple_sinfo->fields.size();
+ } else {
+ return 1;
+ }
+ }
+
+ /*! \brief Map old TIR func param buffer to new buffer, and then update
`buffer_subst_map` */
+ void MapArgsToBuffer(const Array<Expr> args, const Array<tir::Buffer>&
buffers) {
+ size_t buffer_idx = 0;
+ for (const Expr& arg : args) {
+ if (const auto* v = arg.as<VarNode>()) {
+ auto it = func_info_.expr2buffers.find(GetRef<Var>(v));
+ // Substitute the buffer with the already allocated one if it is an
intermediate var
+ if (it != func_info_.expr2buffers.end()) {
+ for (const tir::Buffer& target_buffer : (*it).second) {
+ ICHECK_LT(buffer_idx, buffers.size());
+ const tir::Buffer& buffer = buffers[buffer_idx];
+ // TODO(relax-team): Add support for symbolic shape fusion
+ for (const PrimExpr& shape_expr : buffer->shape) {
+ ICHECK(shape_expr.as<IntImmNode>()) << "Only support constant
shape fusion for now";
+ }
+ func_info_.buffer_subst_map.Set(buffer, target_buffer);
+ buffer_idx++;
+ }
+ }
+ }
+ }
+ // Make sure every buffer is mapped.
+ ICHECK_EQ(buffer_idx, buffers.size());
+ }
+
+ /*!
+ * \brief Update buffer mapping `func_info_.buffer_subst_map` for input args
+ * \param func The old TIR PrimFunc
+ * \param output_size The number of output params. All output params are at
the end of param list.
+ */
+ void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) {
+ Array<Expr> arg_list;
+ Array<tir::Buffer> buffer_list;
+ if (const auto* arg_tuple = args.as<TupleNode>()) {
+ arg_list = arg_tuple->fields;
+ } else {
+ arg_list = {args};
+ }
+
+ ICHECK_GE(func->params.size(), arg_list.size());
+ for (size_t i = 0; i < arg_list.size(); ++i) {
+ const tir::Var& param = func->params[i];
+ const tir::Buffer& buffer = func->buffer_map.at(param);
+ buffer_list.push_back(buffer);
+ }
+
+ MapArgsToBuffer(arg_list, buffer_list);
+ }
+
+ /*!
+ * \brief Allocate buffer(s) and update `func_info.expr2buffers` if the
PrimFunc output(s) are
+ * intermediate results.
+ * \param expr The relax Expr, which can be binding vars or binding values.
+ * \param func The old TIR PrimFunc
+ * \param output_size The number of output params. All output params are at
the end of param list.
+ */
+ void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
size_t output_size) {
+ size_t n = func->params.size();
+ ICHECK_GE(n, output_size);
+ // Allocate intermediate buffer
+ Array<tir::Buffer> alloc_buffers;
+ for (size_t i = 0; i < output_size; ++i) {
+ const tir::Var& param = func->params[n - output_size + i];
+ const tir::Buffer& buffer = func->buffer_map.at(param);
+ func_info_.alloc_buffers.push_back(buffer);
+ alloc_buffers.push_back(buffer);
+ }
+ // Update expr2buffers
+ func_info_.expr2buffers.Set(expr, alloc_buffers);
+ }
+
+ /*!
+ * \brief Create TIR func params and buffers with the specified relax type
and shape
+ * \param struct_info The struct info
+ * \param name_hint The name hint for params and buffers
+ * \param index The index used for unique name_hint if type is Tuple.
+ * -1 means no need to add postfix since the relax param is not
a Tuple.
+ * \return The created TIR func params and buffers
+ */
+ static std::pair<Array<tir::Var>, Array<tir::Buffer>> CreateParamsAndBuffers(
+ StructInfo struct_info, const String& name_hint, int index = -1) {
+ Array<tir::Var> params;
+ Array<tir::Buffer> buffers;
+ if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
+ // Case 1. the relax param is a DynTensor, we directly create a tir var
and buffer
+ const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
+ ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with
symbolic shape.";
+
+ String name = index == -1 ? name_hint : name_hint + "_" +
std::to_string(index);
+ DataType dtype = tensor->dtype;
+ tir::Buffer buffer = tir::decl_buffer(shape_expr->values, dtype, name);
+ // Differentiate buffer name and param name by adding prefix `v_` to
param
+ // Every symbol should be unique in TVMScript, and Buffer is used more
than param
+ // So we decide to make sure buffer names have better readability.
+ tir::Var param = tir::Var("p_" + name, PrimType(DataType::Handle()));
+ params.push_back(std::move(param));
+ buffers.push_back(std::move(buffer));
+ } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
+ // Case 2. the relax param is a Tuple, we recursively visit each field
until it's a DynTensor
+ // Enable postfix
+ if (index == -1) index = 0;
+ for (size_t i = 0; i < tuple->fields.size(); ++i) {
+ auto ret = CreateParamsAndBuffers(tuple->fields[i], name_hint, index);
+ const Array<tir::Var>& ret_params = ret.first;
+ const Array<tir::Buffer>& ret_buffers = ret.second;
+ ICHECK_EQ(ret_params.size(), ret_buffers.size());
+ // Adding tuple field results to the end of params and buffers.
+ params.insert(params.end(), ret_params.begin(), ret_params.end());
+ buffers.insert(buffers.end(), ret_buffers.begin(), ret_buffers.end());
+ index += ret_params.size();
+ }
+ } else {
+ ICHECK(false) << "shapes are expected to be ShapeExprNode or TupleNode";
+ }
+ return std::make_pair(params, buffers);
+ }
+
+ /*!
+ * \brief Construct fused TIR func with collected FuseFuncInfo
+ * \return The fused TIR
+ */
+ tir::PrimFunc ConstructFunc() {
+ Map<String, ObjectRef> attr_map;
+ attr_map.Set("tir.noalias", tir::const_true());
+ ICHECK(func_info_.global_name != "fused");
+ // Remove output buffers from func_info_.alloc_buffers
+ Array<tir::Buffer> alloc_buffers;
+ for (const tir::Buffer& buf : func_info_.alloc_buffers) {
+ if (func_info_.output_buffers.count(buf.get()) == 0) {
+ alloc_buffers.push_back(buf);
+ }
+ }
+ tir::Stmt body =
tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
+ body =
tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
+ body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt,
alloc_buffers);
+ body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
+ tir::PrimFunc func(func_info_.params, body, VoidType(),
func_info_.buffer_map,
+ DictAttrs(attr_map));
+ return func;
+ }
+
+ /*! \brief Get DynTensor numbers from recursive Tuples. */
+ static size_t GetTotalTensorSize(const Type& type) {
+ if (type.as<DynTensorTypeNode>()) {
+ return 1;
+ } else if (const auto* tuple_type = type.as<TupleTypeNode>()) {
+ size_t num = 0;
+ for (const Type& type : tuple_type->fields) {
+ num += GetTotalTensorSize(type);
+ }
+ return num;
+ } else {
+ LOG(FATAL) << "DynTensorType and TupleType are expect, but got: " <<
type;
+ return 0;
+ }
+ }
+
  /********** Function Info **********/

  /*! \brief Auxiliary bookkeeping accumulated while constructing the fused PrimFunc. */
  struct FuseFuncInfo {
    /*! \brief The arguments for calling prim_func */
    Array<Expr> arguments;
    /*!
     * \brief The map from each dataflow var (intermediate var) to the corresponding buffers
     * allocated in the fused func
     */
    Map<Expr, Array<tir::Buffer>> expr2buffers;
    /*! \brief The buffers to allocate in the fused func */
    Array<tir::Buffer> alloc_buffers;
    /*! \brief The bodies of the original funcs, which is also the body of the fused func. */
    Array<tir::Stmt> bodies;
    /*! \brief The params of the fused function */
    Array<tir::Var> params;
    /*!
     * \brief The map from buffer in original functions to corresponding buffer in the fused
     * function
     */
    Map<tir::Buffer, tir::Buffer> buffer_subst_map;
    /*! \brief The `buffer_map` in the fused function */
    Map<tir::Var, tir::Buffer> buffer_map;
    /*! \brief The output buffers in the function buffer_map (raw pointers; keyed by node). */
    std::unordered_set<const tir::BufferNode*> output_buffers;
    /*! \brief The name of the fused function; starts from the "fused" prefix. */
    std::string global_name = "fused";
  };

  /*! \brief The IRModule (held by reference; must outlive this object). */
  const IRModule& mod_;
  /*! \brief The name hint for the input func. */
  String func_name_;
  /*! \brief The helper info to fuse TIR prim_func */
  FuseFuncInfo func_info_;
  /*! \brief The tir function after fusion */
  tir::PrimFunc fused_tir_;
};
+
/*!
 * \brief The helper class to fuse TIR functions and build a new module which calls the fused TIR.
 */
class TIRFuseMutator : public ExprMutator {
 public:
  /*!
   * \brief Rewrite the module: fuse every primitive relax function into one PrimFunc
   * and rewrite the non-primitive functions to call the fused PrimFuncs via call_tir.
   */
  static IRModule Transform(const IRModule& mod) {
    // Since TIRFuseMutator will delete a bunch of PrimFuncs, we create an empty block builder.
    TIRFuseMutator mutator(mod);
    // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_`
    for (const auto& kv : mod->functions) {
      const GlobalVar& gv = kv.first;
      const BaseFunc& func = kv.second;
      // Only fuse primitive relax functions
      if (func->IsInstance<relax::FunctionNode>() && func->HasNonzeroAttr(attr::kPrimitive)) {
        tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv);
        mutator.fused_tir_funcs_.Set(gv, fused_tir);
      }
    }

    // Step 2. Update all non-primitive relax functions and add them, with the dependent
    // functions, into the new IRModule
    for (const auto& kv : mod->functions) {
      const GlobalVar& gv = kv.first;
      const BaseFunc& func = kv.second;
      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
        relax::Function update_func = Downcast<Function>(mutator.VisitExpr(func));
        mutator.builder_->AddFunction(update_func, gv->name_hint);
      }
    }
    return mutator.builder_->GetContextIRModule();
  }

 private:
  explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {}

  using ExprMutator::VisitExpr_;

  // Get shape from call tir
  // NOTE(review): requires every tensor to carry a ShapeExpr; tuples are mapped recursively.
  static Expr GetCallTIRShape(StructInfo sinfo) {
    if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
      Array<Expr> fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); });
      return Tuple(fields);
    } else {
      auto* tensor = sinfo.as<TensorStructInfoNode>();
      ICHECK(tensor) << "FuseTIR can only take tensor or tuple type";
      auto* shape_expr = tensor->shape.as<ShapeExprNode>();
      ICHECK(shape_expr) << "FuseTIR requires all intermediate values have shape";
      return GetRef<ShapeExpr>(shape_expr);
    }
  }

  Expr VisitExpr_(const CallNode* op) final {
    static const Op& call_tir_op_ = Op::Get("relax.call_tir");
    Call call = Downcast<Call>(builder_->Normalize(ExprMutator::VisitExpr_(op)));

    if (call->op->IsInstance<GlobalVarNode>()) {
      // Case 1. It is a relax cross function call
      GlobalVar old_gv = Downcast<GlobalVar>(call->op);
      auto it = fused_tir_funcs_.find(old_gv);
      if (it != fused_tir_funcs_.end()) {
        const tir::PrimFunc& fused_tir = (*it).second;
        // Case 1.1. It calls a primitive relax function, update the call into a call_tir
        GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint);
        // Step a. Flatten all args since call_tir does not support Tuple value.
        Array<Expr> arg_list;
        for (const Expr& arg : call->args) {
          Array<Expr> flattened = FlattenArg(arg);
          arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
        }
        // Step b. Create call_tir
        Array<Expr> call_args = {fused_tir_gv, Tuple(arg_list)};
        return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)});
      } else {
        // Case 1.2. The callee function is not primitive, nothing to do.
        return call;
      }
    } else if (call->op == call_tir_op_) {
      // Case 2. It is a call_tir, re-emit the PrimFunc.
      // (The PrimFunc must be copied into the fresh module built by `builder_`.)
      GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
      GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
      return Call(call->op, {new_gv, call->args[1]}, call->attrs, call->sinfo_args, call->span);
    } else {
      // Case 3. CallNode in other types. Leave it as it is.
      return call;
    }
  }

  /********** Helper Functions **********/

  /*! \brief Flatten the call args if it's Tuple by emitting `TupleGetItem`. */
  Array<Expr> FlattenArg(const Expr& arg) {
    if (const auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(arg)) {
      // Recurse so nested tuples are fully flattened into scalar tensor args.
      Array<Expr> arg_list;
      for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
        Expr new_arg = builder_->Emit(TupleGetItem(arg, i));
        Array<Expr> flattened = FlattenArg(new_arg);
        arg_list.insert(arg_list.end(), flattened.begin(), flattened.end());
      }
      return arg_list;
    } else {
      return {arg};
    }
  }

 private:
  /*! \brief The IRModule (held by reference; must outlive this mutator). */
  const IRModule& mod_;
  /*! \brief The map from global var of primitive relax function to generated prim func. */
  Map<GlobalVar, tir::PrimFunc> fused_tir_funcs_;
};
+
+IRModule FuseTIR(IRModule mod) {
+ mod = TIRFuseMutator::Transform(mod);
+ return mod;
+}
+
+namespace transform {
+
/*! \brief Create the module-level FuseTIR pass exposed to the pass infrastructure. */
Pass FuseTIR() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
      [=](IRModule m, PassContext pc) { return relax::FuseTIR(m); };
  return CreateModulePass(/*pass_function=*/pass_func,  //
                          /*opt_level=*/0,              //
                          /*pass_name=*/"FuseTIR",      //
                          /*required=*/{});
}

// Expose the pass to Python as relax.transform.FuseTIR.
TVM_REGISTER_GLOBAL("relax.transform.FuseTIR").set_body_typed(FuseTIR);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py
b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
new file mode 100644
index 0000000000..73c6537869
--- /dev/null
+++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -0,0 +1,360 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import enum
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T
+
+
# Python-side mirror of the op-pattern kinds assigned by AnnotateTIROpPattern.
# Values 5 and 6 are unused in these tests, hence the jump from 4 to 7.
OpPatternKind = enum.IntEnum(
    "OpPatternKind",
    [
        ("kElemWise", 0),
        ("kBroadcast", 1),
        ("kInjective", 2),
        ("kCommReduce", 3),
        ("kOutEWiseFusable", 4),
        ("kTuple", 7),
        ("kOpaque", 8),
    ],
)
+
+
def test_annotate_opkind_outewisefusable():
    """A matmul PrimFunc (spatial + reduction axes) is annotated kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
            T.func_attr({"global_symbol": "tir_matmul"})
            m = T.var("int32")
            n = T.var("int32")
            k = T.var("int32")
            A = T.match_buffer(x, (m, n))
            B = T.match_buffer(y, (n, k))
            C = T.match_buffer(z, (m, k))

            for i, j, k in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_outewisefusable_int_var_signature():
    """Same matmul annotation, but with shape vars passed as scalar int64 parameters."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64):
            T.func_attr({"global_symbol": "tir_matmul"})
            A = T.match_buffer(x, (m, n))
            B = T.match_buffer(y, (n, k))
            C = T.match_buffer(z, (m, k))

            for i, j, k in T.grid(m, k, n):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float32(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_reduce():
    """A row-sum (one spatial, one reduction axis) is annotated kCommReduce."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def sum(x: T.handle, y: T.handle) -> None:
            T.func_attr({"global_symbol": "elemwise"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16,))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SR", [i, j])
                    with T.init():
                        B[vi] = 0.0
                    B[vi] += A[vi, vj]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce
+
+
def test_annotate_opkind_ewise():
    """A pure element-wise add-constant PrimFunc is annotated kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def elemwise(x: T.handle, y: T.handle) -> None:
            T.func_attr({"global_symbol": "elemwise"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16, 16))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] + 1.0

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_broadcast():
    """A 2D->4D broadcast copy (some output axes absent from the input) is kBroadcast."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def broadcast(x: T.handle, y: T.handle) -> None:
            T.func_attr({"global_symbol": "elemwise"})
            A = T.match_buffer(x, (16, 16))
            B = T.match_buffer(y, (16, 16, 16, 16))

            for i0, j0, i1, j1 in T.grid(16, 16, 16, 16):
                with T.block("matmul"):
                    vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1])
                    B[vi0, vj0, vi1, vj1] = A[vj0, vj1]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast
+
+
def test_annotate_opkind_injective():
    """A reshape-like index remapping (4D -> 2D via div/mod) is annotated kInjective."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def injective(x: T.handle, y: T.handle) -> None:
            T.func_attr({"global_symbol": "elemwise"})
            A = T.match_buffer(x, (4, 4, 4, 4))
            B = T.match_buffer(y, (16, 16))

            for i, j in T.grid(16, 16):
                with T.block("matmul"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective
+
+
def test_annotate_opkind_bias_add():
    """Bias-add (input broadcast over a leading axis) still counts as kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def tir_bias_add(
            A: T.Buffer((1, 1000), "float32"),
            B: T.Buffer((1000,), "float32"),
            C: T.Buffer((1, 1000), "float32"),
        ) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0, i1 in T.grid(1, 1000):
                with T.block("T_add"):
                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
                    T.reads(A[ax0, ax1], B[ax1])
                    T.writes(C[ax0, ax1])
                    C[ax0, ax1] = A[ax0, ax1] + B[ax1]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_add_broadcast_with_unit_shape():
    """Add where one operand has unit-length trailing dims is still annotated kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def add_with_unit_dim_len_broadcast(
            A: T.Buffer((1, 64, 112, 112), "float32"),
            B: T.Buffer((64, 1, 1), "float32"),
            C: T.Buffer((1, 64, 112, 112), "float32"),
        ) -> None:
            T.func_attr({"global_symbol": "add5", "tir.noalias": True})
            for i0, i1, i2, i3 in T.grid(1, 64, 112, 112):
                with T.block("T_add"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0])
                    T.writes(C[ax0, ax1, ax2, ax3])
                    C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_add_zero_dim_element_wise():
    """Add with a zero-rank (scalar) buffer operand is annotated kElemWise."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def add_zero_dim(
            A: T.Buffer((128,), "float32"),
            B: T.Buffer((), "float32"),
            C: T.Buffer((128,), "float32"),
        ) -> None:
            T.func_attr({"global_symbol": "add8", "tir.noalias": True})
            for i0 in T.serial(128):
                with T.block("T_add"):
                    ax0 = T.axis.spatial(128, i0)
                    T.reads(A[ax0], B[()])
                    T.writes(C[ax0])
                    C[ax0] = A[ax0] + B[()]

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise
+
+
def test_annotate_opkind_pooling():
    """Max-pool (padding block + windowed reduction) is annotated kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def max_pool2d(
            rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"),
            tensor_1: T.Buffer((1, 64, 56, 56), "float32"),
        ) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "max_pool2d", "T.noalias": True})
            # body
            # with T.block("root")
            pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32")
            for i0, i1, i2, i3 in T.grid(1, 64, 114, 114):
                with T.block("pad_temp"):
                    ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                    T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])
                    T.writes(pad_temp_1[ax0, ax1, ax2, ax3])
                    # Out-of-bounds reads are padded with -inf (float32 lowest).
                    pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else(
                        1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113,
                        rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1],
                        T.float32(-3.4028234663852886e38),
                        dtype="float32",
                    )
            for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3):
                with T.block("tensor"):
                    ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                    T.reads(
                        tensor_1[ax0, ax1, ax2, ax3],
                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                    )
                    T.writes(tensor_1[ax0, ax1, ax2, ax3])
                    with T.init():
                        tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38)
                    tensor_1[ax0, ax1, ax2, ax3] = T.max(
                        tensor_1[ax0, ax1, ax2, ax3],
                        pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
                    )

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_annotate_opkind_softmax():
    """A multi-block softmax (max, exp, sum, normalize) is annotated kOutEWiseFusable."""

    @tvm.script.ir_module
    class InputModule:
        @T.prim_func
        def softmax(
            rxplaceholder_1: T.Buffer((16, 16), "float32"),
            T_softmax_norm_1: T.Buffer((16, 16), "float32"),
        ) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "softmax", "T.noalias": True})
            # body
            # with T.block("root")
            T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32")
            T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32")
            T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32")
            # Row-wise max for numerical stability.
            for i0_7, i1_3 in T.grid(16, 16):
                with T.block("T_softmax_maxelem"):
                    i0_8, k = T.axis.remap("SR", [i0_7, i1_3])
                    T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
                    T.writes(T_softmax_maxelem_1[i0_8])
                    with T.init():
                        T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38)
                    T_softmax_maxelem_1[i0_8] = T.max(
                        T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]
                    )
            for i0_9, i1_4 in T.grid(16, 16):
                with T.block("T_softmax_exp"):
                    i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4])
                    T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
                    T.writes(T_softmax_exp_1[i0_10, i1_5])
                    T_softmax_exp_1[i0_10, i1_5] = T.exp(
                        rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32"
                    )
            for i0_11, i1_6 in T.grid(16, 16):
                with T.block("T_softmax_expsum"):
                    i0_12, k = T.axis.remap("SR", [i0_11, i1_6])
                    T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
                    T.writes(T_softmax_expsum_1[i0_12])
                    with T.init():
                        T_softmax_expsum_1[i0_12] = T.float32(0)
                    T_softmax_expsum_1[i0_12] = (
                        T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
                    )
            for i0_13, i1_7 in T.grid(16, 16):
                with T.block("T_softmax_norm"):
                    i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7])
                    T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
                    T.writes(T_softmax_norm_1[i0_14, i1_8])
                    T.block_attr({"axis": 1})
                    T_softmax_norm_1[i0_14, i1_8] = (
                        T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
                    )

    mod = InputModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable
+
+
def test_multiple_bufer_stores_fallback():
    """A block with multiple stores to the same buffer (cumsum) falls back to kOpaque."""

    @tvm.script.ir_module
    class CumsumModule:
        @T.prim_func
        def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")):
            rxplaceholder = T.match_buffer(
                var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1
            )
            with T.block("cumsum_generic"):
                T.reads(rxplaceholder[0:10, 0:16])
                T.writes(out_buf[0:160])
                for fused in T.parallel(1):
                    out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16]
                    # Sequential scan: each element depends on the previous output.
                    for v_k in T.serial(159):
                        out_buf[fused * 160 + (v_k + 1)] = (
                            out_buf[fused * 160 + (v_k + 1 - 1)]
                            + rxplaceholder[
                                (fused * 160 + (v_k + 1)) // 16,
                                (fused * 160 + (v_k + 1)) % 16,
                            ]
                        )

    mod = CumsumModule
    new_mod = relax.transform.AnnotateTIROpPattern()(mod)
    assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque
+
+
# Allow running this test file directly.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
new file mode 100644
index 0000000000..1a228bb268
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -0,0 +1,759 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax, topi
+from tvm.script import relax as R
+
+
def _check(mod_actual, mod_expected):
    """Annotate both modules, run FuseOps on the actual one, and compare structurally."""
    annotate = relax.transform.AnnotateTIROpPattern()
    actual = relax.transform.FuseOps()(annotate(mod_actual))
    expected = annotate(mod_expected)
    tvm.ir.assert_structural_equal(actual, expected)
+
+
def test_fuse_simple():
    """Simple testcase: a straight add->exp->squeeze chain fuses into one function."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        # The constant operand becomes an extra parameter `p0` of the fused function.
        p0 = relax.Var("p0", R.Tensor((), "float32"))

        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        fused_add_exp_squeeze = bb.get().get_global_var("fused_add_exp_squeeze")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(
                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, "float32")])
                )
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_conv2d_fuse():
    """Test fusion case of conv2d: fusion stops at the conv anchors, giving two groups."""

    def before(dtype):
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit_te(topi.nn.conv2d, lv0, w1, strides=1, padding=1, dilation=1)
                # this is the next dominator.
                lv2 = bb.emit_te(topi.add, relax.const(1, dtype), lv1)
                lv3 = bb.emit_te(topi.add, lv1, lv2)
                # second path
                lv4 = bb.emit_te(topi.nn.conv2d, lv3, w2, strides=1, padding=0, dilation=1)
                lv5 = bb.emit_te(topi.nn.conv2d, lv3, w3, strides=1, padding=1, dilation=1)
                gv = bb.emit_output(bb.call_te(topi.add, lv4, lv5))
            bb.emit_func_output(gv)

        return bb.get()

    def expected(dtype):
        bb = relax.BlockBuilder()

        # Grouped function 1
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
        p0 = relax.Var("p0", R.Tensor((), dtype))
        with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=1,
                    dilation=1,
                    primfunc_name_hint="conv2d",
                )
                lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Grouped function 2
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
        y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=0,
                    dilation=1,
                    primfunc_name_hint="conv2d1",
                )
                gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        mod = bb.get()
        fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
        fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)]))
                lv2 = bb.emit_te(
                    topi.nn.conv2d,
                    lv1,
                    w3,
                    strides=1,
                    padding=1,
                    dilation=1,
                )
                gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2]))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before("float32"), expected("float32"))
    _check(before("float16"), expected("float16"))
    _check(before("int8"), expected("int8"))
+
+
def test_concatenate():
    """Test fusion case involving concat op and Tuple node."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.pool2d,
                    x,
                    kernel=(2, 2),
                    stride=(2, 2),
                    dilation=(1, 1),
                    padding=(0, 0, 0, 0),
                    pool_type="max",
                )
                lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0)
                lv2 = bb.emit_te(topi.concatenate, (lv1, x), axis=1)
                gv = bb.emit_output(bb.call_te(topi.add, lv2, relax.const(1, "float32")))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()

        # Grouped function: upsampling + concatenate + add fuse; pool2d stays outside.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32"))
        p0 = relax.Var("p0", R.Tensor((), "float32"))
        with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0)
                lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1)
                gv = bb.emit_output(bb.call_te(topi.add, lv1, p0))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        fused_upsampling_concatenate_add = bb.get().get_global_var(
            "fused_upsampling_concatenate_add"
        )

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.pool2d,
                    x,
                    kernel=(2, 2),
                    stride=(2, 2),
                    dilation=(1, 1),
                    padding=(0, 0, 0, 0),
                    pool_type="max",
                )
                gv = bb.emit_output(
                    relax.Call(
                        fused_upsampling_concatenate_add, (lv0, x, relax.const(1, "float32"))
                    )
                )
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_tuple_root():
    """Test fusion case where Tuple node is the root in its group."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.pool2d,
                    x,
                    kernel=(2, 2),
                    stride=(2, 2),
                    dilation=(1, 1),
                    padding=(0, 0, 0, 0),
                    pool_type="max",
                )
                lv1 = bb.emit_te(topi.nn.upsampling, lv0, scale_h=2.0, scale_w=2.0)
                gv = bb.emit_output((lv1, x))
            bb.emit_func_output(gv)

        return bb.get()

    # The fusion is supposed to make no change.
    _check(before(), before())
+
+
def test_fuse_tuple_get_elemwise():
    """TupleGetItem consumers of a split fuse together with the following elemwise ops."""

    def before(dim: int):
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor((1, dim), "float32"))
        w = relax.Var("w", R.Tensor((3 * dim, dim), "float32"))
        with bb.function("main", [x, w]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.nn.dense, x, w)
                lv1 = bb.emit_te(topi.split, lv0, indices_or_sections=3, axis=1)
                lv2 = bb.emit(relax.TupleGetItem(lv1, 0))
                lv3 = bb.emit_te(topi.sigmoid, lv2)
                lv4 = bb.emit(relax.TupleGetItem(lv1, 1))
                lv5 = bb.emit_te(topi.tanh, lv4)
                lv6 = bb.emit(relax.TupleGetItem(lv1, 2))
                lv7 = bb.emit_te(topi.exp, lv6)
                lv8 = bb.emit_te(topi.multiply, lv5, lv7)
                gv = bb.emit_output(bb.call_te(topi.add, lv3, lv8))
            bb.emit_func_output(gv)

        return bb.get()

    def expected(dim: int):
        bb = relax.BlockBuilder()

        # Grouped function: everything after the dense anchor fuses into one group.
        dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32"))
        with bb.function(
            "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1}
        ):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1)
                lv1 = bb.emit(relax.TupleGetItem(lv0, 0))
                lv2 = bb.emit_te(topi.sigmoid, lv1)
                lv3 = bb.emit(relax.TupleGetItem(lv0, 1))
                lv4 = bb.emit_te(topi.tanh, lv3)
                lv5 = bb.emit(relax.TupleGetItem(lv0, 2))
                lv6 = bb.emit_te(topi.exp, lv5)
                lv7 = bb.emit_te(topi.multiply, lv4, lv6)
                gv = bb.emit_output(bb.call_te(topi.add, lv2, lv7))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        fused_split_sigmoid_tanh_exp_multiply_add = bb.get().get_global_var(
            "fused_split_sigmoid_tanh_exp_multiply_add"
        )

        # Main function
        x = relax.Var("x", R.Tensor((1, dim), "float32"))
        w = relax.Var("w", R.Tensor((3 * dim, dim), "float32"))
        with bb.function("main", [x, w]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.nn.dense, x, w)
                gv = bb.emit_output(relax.Call(fused_split_sigmoid_tanh_exp_multiply_add, (lv0,)))
            bb.emit_func_output(gv)

        return bb.get()

    dim = 10
    _check(before(dim), expected(dim))
+
+
def test_tuple_get_root():
    """A split whose group ends at a TupleGetItem still becomes its own fused function."""

    def before(dim: int):
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
        w = relax.Var("w", R.Tensor((dim, dim), "float32"))
        with bb.function("main", [x, w]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
                lv1 = bb.emit(relax.TupleGetItem(lv0, 0))
                gv = bb.emit_output(bb.call_te(topi.nn.dense, lv1, w))
            bb.emit_func_output(gv)

        return bb.get()

    def expected(dim: int):
        bb = relax.BlockBuilder()

        # Grouped function
        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
        with bb.function("fused_split", [x], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1)
                gv = bb.emit_output(relax.TupleGetItem(lv0, 0))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        fused_split = bb.get().get_global_var("fused_split")

        # Main function
        x = relax.Var("x", R.Tensor((1, 3 * dim), "float32"))
        w = relax.Var("w", R.Tensor((dim, dim), "float32"))
        with bb.function("main", [x, w]):
            with bb.dataflow():
                lv0 = bb.emit(relax.Call(fused_split, (x,)))
                gv = bb.emit_output(bb.call_te(topi.nn.dense, lv0, w))
            bb.emit_func_output(gv)

        return bb.get()

    dim = 10
    _check(before(dim), expected(dim))
+
+
def test_tuple_intermediate():
    """A Tuple feeding concatenate in the middle of the graph fuses into a single group."""

    def before():
        bb = relax.BlockBuilder()

        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.squeeze, x)
                lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
                lv2 = bb.emit_te(topi.squeeze, lv0)
                lv3 = bb.emit_te(topi.add, lv2, relax.const(1, "float32"))
                lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32"))
                lv5 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
                lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1)
                lv7 = bb.emit_te(topi.squeeze, lv6)
                gv = bb.emit_output(bb.call_te(topi.add, lv7, relax.const(1, "float32")))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()

        # Grouped function: one fused function; each constant becomes a parameter p0..p4.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        p0 = relax.Var("p0", R.Tensor((), "float32"))
        p1 = relax.Var("p1", R.Tensor((), "float32"))
        p2 = relax.Var("p2", R.Tensor((), "float32"))
        p3 = relax.Var("p3", R.Tensor((), "float32"))
        p4 = relax.Var("p4", R.Tensor((), "float32"))
        with bb.function(
            "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1",
            [x, p0, p1, p2, p3, p4],
            attrs={"Primitive": 1},
        ):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.squeeze, x)
                lv1 = bb.emit_te(topi.add, lv0, p0)
                lv2 = bb.emit_te(topi.squeeze, lv0)
                lv3 = bb.emit_te(topi.add, lv2, p1)
                lv4 = bb.emit_te(topi.add, lv3, p2)
                lv5 = bb.emit_te(topi.add, lv0, p3)
                lv6 = bb.emit_te(topi.concatenate, (lv1, lv4, lv5), axis=1)
                lv7 = bb.emit_te(topi.squeeze, lv6)
                gv = bb.emit_output(bb.call_te(topi.add, lv7, p4))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        fused_func = bb.get().get_global_var(
            "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1"
        )

        # Main func
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(
                    relax.Call(
                        fused_func,
                        (
                            x,
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                        ),
                    )
                )
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_tuple_consecutive():
    """Fusion across consecutive tuple-consuming (concatenate) groups.

    ``before`` builds three parallel ``add/add/add -> concatenate -> add``
    chains from ``x``, concatenates the three results, then applies a max
    pool2d followed by two adds, returning ``(lv17, lv18)``.

    ``expected`` groups the three chains together with the final concatenate
    into one primitive function, and ``pool2d`` plus the following ``add``
    into a second one.  The last ``add`` stays unfused in ``main`` because
    both its input (lv1) and its output (lv2) are function outputs.
    """

    def before():
        bb = relax.BlockBuilder()

        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                # Three identical add/add/add -> concat -> add chains.
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv1 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv2 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1)
                lv4 = bb.emit_te(topi.add, lv3, relax.const(1, "float32"))
                lv5 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv6 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv7 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1)
                lv9 = bb.emit_te(topi.add, lv8, relax.const(1, "float32"))
                lv10 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv11 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv12 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
                lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1)
                lv14 = bb.emit_te(topi.add, lv13, relax.const(1, "float32"))
                # Combine the three chain results, then pool and add twice.
                lv15 = bb.emit_te(topi.concatenate, (lv4, lv9, lv14), axis=1)
                lv16 = bb.emit_te(
                    topi.nn.pool2d,
                    lv15,
                    kernel=(2, 2),
                    stride=(2, 2),
                    dilation=(1, 1),
                    padding=(0, 0, 0, 0),
                    pool_type="max",
                )
                lv17 = bb.emit_te(topi.add, lv16, relax.const(1, "float32"))
                lv18 = bb.emit_te(topi.add, lv17, relax.const(1, "float32"))
                gv = bb.emit_output((lv17, lv18))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()

        # Grouped function 1: all three add-chains plus the final concatenate.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        p0 = relax.Var("p0", R.Tensor((), "float32"))
        p1 = relax.Var("p1", R.Tensor((), "float32"))
        p2 = relax.Var("p2", R.Tensor((), "float32"))
        p3 = relax.Var("p3", R.Tensor((), "float32"))
        p4 = relax.Var("p4", R.Tensor((), "float32"))
        p5 = relax.Var("p5", R.Tensor((), "float32"))
        p6 = relax.Var("p6", R.Tensor((), "float32"))
        p7 = relax.Var("p7", R.Tensor((), "float32"))
        p8 = relax.Var("p8", R.Tensor((), "float32"))
        p9 = relax.Var("p9", R.Tensor((), "float32"))
        p10 = relax.Var("p10", R.Tensor((), "float32"))
        p11 = relax.Var("p11", R.Tensor((), "float32"))
        with bb.function(
            "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1",
            [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11],
            attrs={"Primitive": 1},
        ):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.add, x, p1)
                lv2 = bb.emit_te(topi.add, x, p2)
                lv3 = bb.emit_te(topi.concatenate, (lv0, lv1, lv2), axis=1)
                lv4 = bb.emit_te(topi.add, lv3, p3)
                lv5 = bb.emit_te(topi.add, x, p4)
                lv6 = bb.emit_te(topi.add, x, p5)
                lv7 = bb.emit_te(topi.add, x, p6)
                lv8 = bb.emit_te(topi.concatenate, (lv5, lv6, lv7), axis=1)
                lv9 = bb.emit_te(topi.add, lv8, p7)
                lv10 = bb.emit_te(topi.add, x, p8)
                lv11 = bb.emit_te(topi.add, x, p9)
                lv12 = bb.emit_te(topi.add, x, p10)
                lv13 = bb.emit_te(topi.concatenate, (lv10, lv11, lv12), axis=1)
                lv14 = bb.emit_te(topi.add, lv13, p11)
                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv4, lv9, lv14), axis=1))
            bb.emit_func_output(gv)

        # Grouped function 2: pool2d followed by an add.
        concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32"))
        p0 = relax.Var("p0", R.Tensor((), "float32"))
        with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.pool2d,
                    concat,
                    kernel=(2, 2),
                    stride=(2, 2),
                    dilation=(1, 1),
                    padding=(0, 0, 0, 0),
                    pool_type="max",
                )
                gv = bb.emit_output(bb.call_te(topi.add, lv0, p0))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        mod = bb.get()
        fused_func1 = mod.get_global_var(
            "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1"
        )
        fused_func2 = mod.get_global_var("fused_pool2d_add2")

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit(
                    relax.Call(
                        fused_func1,
                        (
                            x,
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                            relax.const(1, "float32"),
                        ),
                    )
                )
                lv1 = bb.emit(relax.Call(fused_func2, (lv0, relax.const(1, "float32"))))
                # The final add stays in main: lv1 and lv2 are both outputs.
                lv2 = bb.emit_te(topi.add, lv1, relax.const(1, "float32"))
                gv = bb.emit_output((lv1, lv2))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_inception_like():
    """Inception-style graph: parallel conv2d+relu branches joined by concat.

    ``before`` runs two conv2d+relu branches on ``x``, concatenates them,
    then runs two more conv2d+relu branches on the concat result and
    concatenates again.

    ``expected`` fuses each conv2d+relu pair into a primitive function —
    one per weight shape (16x16 and 16x32 kernels) — each called twice,
    while the two concatenates remain in ``main``.
    """

    def before():
        bb = relax.BlockBuilder()

        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32"))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32"))
        w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32"))
        w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32"))
        with bb.function("main", [x, w0, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.nn.conv2d, x, w0, strides=1, padding=1, dilation=1)
                lv1 = bb.emit_te(topi.nn.relu, lv0)
                lv2 = bb.emit_te(topi.nn.conv2d, x, w1, strides=1, padding=1, dilation=1)
                lv3 = bb.emit_te(topi.nn.relu, lv2)
                lv4 = bb.emit_te(topi.concatenate, (lv1, lv3), axis=1)
                lv5 = bb.emit_te(topi.nn.conv2d, lv4, w2, strides=1, padding=1, dilation=1)
                lv6 = bb.emit_te(topi.nn.relu, lv5)
                lv7 = bb.emit_te(topi.nn.conv2d, lv4, w3, strides=1, padding=1, dilation=1)
                lv8 = bb.emit_te(topi.nn.relu, lv7)
                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv6, lv8), axis=1))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()

        # Grouped function 1: conv2d+relu on 16-channel input.
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32"))
        with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=1,
                    dilation=1,
                    primfunc_name_hint="conv2d",
                )
                gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0))
            bb.emit_func_output(gv)

        # Grouped function 2: conv2d+relu on the 32-channel concat result.
        x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32"))
        w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32"))
        with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=1,
                    dilation=1,
                    primfunc_name_hint="conv2d1",
                )
                gv = bb.emit_output(bb.call_te(topi.nn.relu, lv0))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        mod = bb.get()
        fused_conv2d_relu1 = mod.get_global_var("fused_conv2d_relu")
        fused_conv2d_relu2 = mod.get_global_var("fused_conv2d1_relu")

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32"))
        w0 = relax.Var("w0", R.Tensor((16, 16, 3, 3), "float32"))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), "float32"))
        w2 = relax.Var("w2", R.Tensor((16, 32, 3, 3), "float32"))
        w3 = relax.Var("w3", R.Tensor((16, 32, 3, 3), "float32"))
        with bb.function("main", [x, w0, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w0)))
                lv1 = bb.emit(relax.Call(fused_conv2d_relu1, (x, w1)))
                lv2 = bb.emit_te(topi.concatenate, (lv0, lv1), axis=1)
                lv3 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w2)))
                lv4 = bb.emit(relax.Call(fused_conv2d_relu2, (lv2, w3)))
                gv = bb.emit_output(bb.call_te(topi.concatenate, (lv3, lv4), axis=1))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_parallel_injective():
    """Parallel injective consumers of the same producer fuse into one group.

    ``lv0`` (add) feeds both a squeeze and a transpose->transpose chain,
    which rejoin at ``left_shift``; the whole diamond becomes a single
    primitive function in ``expected``.
    """

    def before():
        bb = relax.BlockBuilder()

        x = relax.Var("x", R.Tensor((10, 20), "int32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, "int32"))
                lv1 = bb.emit_te(topi.squeeze, lv0)
                lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0])
                lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0])
                gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()

        # Grouped function
        x = relax.Var("x", R.Tensor((10, 20), "int32"))
        p0 = relax.Var("p0", R.Tensor((), "int32"))
        with bb.function(
            "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1}
        ):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, p0)
                lv1 = bb.emit_te(topi.squeeze, lv0)
                lv2 = bb.emit_te(topi.transpose, lv0, axes=[1, 0])
                lv3 = bb.emit_te(topi.transpose, lv2, axes=[1, 0], primfunc_name_hint="transpose1")
                gv = bb.emit_output(bb.call_te(topi.left_shift, lv1, lv3))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        fused_func = bb.get().get_global_var("fused_add_squeeze_transpose_transpose1_left_shift")

        # Main function
        x = relax.Var("x", R.Tensor((10, 20), "int32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(fused_func, (x, relax.const(1, "int32"))))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before(), expected())
+
+
def test_softmax():
    """A softmax producer should fuse with its following cast consumer."""

    def before():
        builder = relax.BlockBuilder()

        inp = relax.Var("x", R.Tensor((16, 16), "float32"))
        with builder.function("main", [inp]):
            with builder.dataflow():
                softmax_out = builder.emit_te(topi.nn.softmax, inp)
                out = builder.emit_output(builder.call_te(topi.cast, softmax_out, dtype="float16"))
            builder.emit_func_output(out)

        return builder.get()

    def expected():
        builder = relax.BlockBuilder()

        # The grouped (primitive) function holding softmax followed by cast.
        inp = relax.Var("x", R.Tensor((16, 16), "float32"))
        with builder.function("fused_softmax_cast", [inp], attrs={"Primitive": 1}):
            with builder.dataflow():
                softmax_out = builder.emit_te(topi.nn.softmax, inp)
                out = builder.emit_output(builder.call_te(topi.cast, softmax_out, dtype="float16"))
            builder.emit_func_output(out)

        # Look up the grouped function so main can call it.
        fused_func = builder.get().get_global_var("fused_softmax_cast")

        # Main simply forwards its input into the fused function.
        main_in = relax.Var("x", R.Tensor((16, 16), "float32"))
        with builder.function("main", [main_in]):
            with builder.dataflow():
                out = builder.emit_output(relax.Call(fused_func, (main_in,)))
            builder.emit_func_output(out)

        return builder.get()

    _check(before(), expected())
+
+
# Allow running this test file directly; delegates to TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_tir.py
b/tests/python/relax/test_transform_fuse_tir.py
new file mode 100644
index 0000000000..91edab2bbb
--- /dev/null
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -0,0 +1,563 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax, topi
+from tvm.script import relax as R
+
+
def _check(mod_before, mod_expected):
    """Apply FuseTIR to *mod_before* and assert it matches *mod_expected*."""
    fused = relax.transform.FuseTIR()(mod_before)
    tvm.ir.assert_structural_equal(fused, mod_expected)
+
+
def test_simple():
    """A three-op grouped relax function collapses into one fused PrimFunc."""

    def before():
        builder = relax.BlockBuilder()
        inp = relax.Var("x", R.Tensor([10, 20], "float32"))
        scalar = relax.Var("p0", R.Tensor([], "float32"))

        # Grouped primitive function: add -> exp -> squeeze.
        with builder.function("fused_add_exp_squeeze", [inp, scalar], attrs={"Primitive": True}):
            with builder.dataflow():
                added = builder.emit_te(topi.add, inp, scalar)
                exped = builder.emit_te(topi.exp, added)
                out = builder.emit_output(builder.call_te(topi.squeeze, exped))
            builder.emit_func_output(out)
        fused_gv = builder.get().get_global_var("fused_add_exp_squeeze")

        inp = relax.Var("x", R.Tensor([10, 20], "float32"))
        scalar = relax.Var("p0", R.Tensor([], "float32"))
        with builder.function("main", [inp, scalar]):
            with builder.dataflow():
                out = builder.emit_output(relax.Call(fused_gv, [inp, scalar]))
            builder.emit_func_output(out)

        return builder.get()

    def expected():
        def fused_add_exp_squeeze(x, p0):
            # Same three topi ops composed as a single TE computation.
            return topi.squeeze(topi.exp(topi.add(x, p0)))

        builder = relax.BlockBuilder()
        inp = relax.Var("x", R.Tensor([10, 20], "float32"))
        scalar = relax.Var("p0", R.Tensor([], "float32"))
        with builder.function("main", [inp, scalar]):
            with builder.dataflow():
                out = builder.emit_output(builder.call_te(fused_add_exp_squeeze, inp, scalar))
            builder.emit_func_output(out)
        return builder.get()

    _check(before(), expected())
+
+
def test_conv2d_fuse():
    """FuseTIR on two grouped conv2d functions plus unfused ops in main.

    Each grouped relax function in ``before`` (conv2d+add1+add2 and
    conv2d1+add2) becomes, in ``expected``, a single TE function composing
    the same topi ops, invoked via ``call_te``; the standalone add and
    conv2d calls in ``main`` are untouched.
    """

    def before(dtype):
        bb = relax.BlockBuilder()

        # Grouped function 1
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
        p0 = relax.Var("p0", R.Tensor((), dtype))
        with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=1,
                    dilation=1,
                    primfunc_name_hint="conv2d",
                )
                lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
                gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Grouped function 2
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
        y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
        with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(
                    topi.nn.conv2d,
                    x,
                    w,
                    strides=1,
                    padding=0,
                    dilation=1,
                    primfunc_name_hint="conv2d1",
                )
                gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2"))
            bb.emit_func_output(gv)

        # Get the global variables of the grouped functions
        mod = bb.get()
        fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
        fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)]))
                lv2 = bb.emit_te(
                    topi.nn.conv2d,
                    lv1,
                    w3,
                    strides=1,
                    padding=1,
                    dilation=1,
                )
                gv = bb.emit_output(relax.Call(fused_conv2d1_add2, [lv1, w2, lv2]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected(dtype):
        # TE equivalents of the two grouped functions above.
        def fused_conv2d_add1_add2(x, w, p):
            conv = topi.nn.conv2d(x, w, strides=1, padding=1, dilation=1)
            add = topi.add(p, conv)
            return topi.add(conv, add)

        def fused_conv2d1_add2(x, w, p):
            conv = topi.nn.conv2d(x, w, strides=1, padding=0, dilation=1)
            return topi.add(conv, p)

        bb = relax.BlockBuilder()

        # Main function
        x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
        w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
        w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
        w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
        with bb.function("main", [x, w1, w2, w3]):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
                lv1 = bb.emit_te(fused_conv2d_add1_add2, lv0, w1, relax.const(1, dtype))
                lv2 = bb.emit_te(
                    topi.nn.conv2d,
                    lv1,
                    w3,
                    strides=1,
                    padding=1,
                    dilation=1,
                )
                gv = bb.emit_output(bb.call_te(fused_conv2d1_add2, lv1, w2, lv2))
            bb.emit_func_output(gv)

        return bb.get()

    _check(before("float32"), expected("float32"))
+
+
def test_two_subfunction():
    """A grouped function called twice in main yields one shared fused PrimFunc
    invoked from two call sites."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}):
            with bb.dataflow():
                lv1 = bb.emit_te(topi.exp, x1)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                # The same fused function is applied twice, back to back.
                lv = bb.emit(relax.Call(func_gv, [x]))
                lv2 = bb.emit(relax.Call(func_gv, [lv]))
                gv = bb.emit_output(lv2)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        def fused_exp_squeeze(x):
            exp = topi.exp(x)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit_te(fused_exp_squeeze, x)
                lv2 = bb.emit_te(fused_exp_squeeze, lv)
                gv = bb.emit_output(lv2)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_same_primfunc():
    """A grouped function that invokes the same op (exp) twice fuses cleanly."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with bb.function("fused_exp_exp_squeeze", [x1], attrs={"Primitive": True}):
            with bb.dataflow():
                # exp is emitted twice, producing two calls to one PrimFunc.
                lv1 = bb.emit_te(topi.exp, x1)
                lv2 = bb.emit_te(topi.exp, lv1)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv2))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit(relax.Call(func_gv, [x]))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        def fused_exp_exp_squeeze(x):
            exp = topi.exp(x)
            exp = topi.exp(exp)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit_te(fused_exp_exp_squeeze, x)
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_with_tuple_as_param():
    """A tuple parameter of the grouped function is flattened by FuseTIR;
    the TupleGetItem projections move out into main."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with bb.function("fused_exp_add", [x], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                lv2 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.add, lv2, lv1))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_add")
        x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(func_gv, [x]))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # The fused TE function takes the flattened tuple fields directly.
        def fused_exp_add(x1, x2):
            exp = topi.exp(x1)
            return topi.add(exp, x2)

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")]))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                gv = bb.emit_output(bb.call_te(fused_exp_add, lv0, lv1))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_with_nested_tuple_as_param():
    """Like test_fuse_with_tuple_as_param, but with a nested tuple parameter
    that must be flattened recursively."""

    # Shared struct info: (Tensor, (Tensor, Tensor)).
    tuple_struct_info = R.Tuple(
        [R.Tensor([10], "float32"), R.Tuple([R.Tensor([10], "float32"), R.Tensor([10], "float32")])]
    )

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", tuple_struct_info)
        with bb.function("fused_exp_add_add", [x], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv0_exp = bb.emit_te(topi.exp, lv0)
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                lv1_0 = bb.emit(relax.TupleGetItem(lv1, 0))
                lv1_1 = bb.emit(relax.TupleGetItem(lv1, 1))
                lv2 = bb.emit_te(topi.add, lv1_0, lv1_1)
                gv = bb.emit_output(bb.call_te(topi.add, lv0_exp, lv2))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_add_add")
        x = relax.Var("x", tuple_struct_info)
        with bb.function("main", [x]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(func_gv, [x]))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        # The fused TE function receives all three leaf tensors.
        def fused_exp_add_add(x1, x2, x3):
            exp = topi.exp(x1)
            add = topi.add(x2, x3)
            return topi.add(exp, add)

        bb = relax.BlockBuilder()
        x = relax.Var("x", tuple_struct_info)
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit(relax.TupleGetItem(x, 0))
                lv1 = bb.emit(relax.TupleGetItem(x, 1))
                lv2 = bb.emit(relax.TupleGetItem(lv1, 0))
                lv3 = bb.emit(relax.TupleGetItem(lv1, 1))
                gv = bb.emit_output(bb.call_te(fused_exp_add_add, lv0, lv2, lv3))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_with_call_tir_in_main():
    """main mixes a call to the grouped function with a direct call_tir (add);
    only the grouped function is rewritten by FuseTIR."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with bb.function("fused_exp_squeeze", [x1], attrs={"Primitive": True}):
            with bb.dataflow():
                lv = bb.emit_te(topi.exp, x1)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv0 = bb.emit(relax.Call(func_gv, [x]))
                # Plain call_tir in main, outside any grouped function.
                lv1 = bb.emit_te(topi.add, lv0, relax.const(1, "float32"))
                gv = bb.emit_output(lv1)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        def fused_exp_squeeze(x):
            exp = topi.exp(x)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit_te(fused_exp_squeeze, x)
                lv2 = bb.emit_te(topi.add, lv, relax.const(1, "float32"))
                gv = bb.emit_output(lv2)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_with_const_in_argument():
    """A relax constant passed as an argument to the grouped function survives
    fusion as a call_te argument."""

    def before():
        bb = relax.BlockBuilder()
        x1 = relax.Var("x1", R.Tensor([10, 20], "float32"))
        x2 = relax.Var("x2", R.Tensor([], "float32"))
        with bb.function("fused_add_exp_squeeze", [x1, x2], attrs={"Primitive": True}):
            with bb.dataflow():
                lv0 = bb.emit_te(topi.add, x1, x2)
                lv1 = bb.emit_te(topi.exp, lv0)
                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_add_exp_squeeze")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                # Second argument is a constant, not a bound variable.
                lv = bb.emit(relax.Call(func_gv, [x, relax.const(1, "float32")]))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        def fused_add_exp_squeeze(x, y):
            add = topi.add(x, y)
            exp = topi.exp(add)
            squeeze = topi.squeeze(exp)
            return squeeze

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x]):
            with bb.dataflow():
                lv = bb.emit_te(fused_add_exp_squeeze, x, relax.const(1, "float32"))
                gv = bb.emit_output(lv)
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_tuple_output():
    """A grouped function returning a tuple of two outputs becomes one fused
    TE function that returns both tensors."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))

        with bb.function("fused_add_exp", [x, p0], attrs={"Primitive": True}):
            with bb.dataflow():
                # Both the intermediate (add) and the final (exp) results
                # are emitted as outputs and returned as a tuple.
                gv0 = bb.emit_output(bb.call_te(topi.add, x, p0))
                gv1 = bb.emit_output(bb.call_te(topi.exp, gv0))
            bb.emit_func_output(relax.Tuple([gv0, gv1]))
        fused_add_exp = bb.get().get_global_var("fused_add_exp")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with bb.function("main", [x, p0]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(fused_add_exp, [x, p0]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        def fused_add_exp(x, p0):
            add = topi.add(x, p0)
            exp = topi.exp(add)
            return add, exp

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        p0 = relax.Var("p0", R.Tensor([], "float32"))
        with bb.function("main", [x, p0]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(fused_add_exp, x, p0))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_with_immediate_tuple():
    """A tuple built and immediately indexed inside the grouped function is
    eliminated; the result is a plain add PrimFunc named ``fused_add``."""

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))

        with bb.function("fused_add", [x, y], attrs={"Primitive": True}):
            with bb.dataflow():
                # Pack (x, (x, y)) then unpack it again — a no-op detour.
                lv_tuple = bb.emit(relax.Tuple([x, relax.Tuple([x, y])]))
                lv_x = bb.emit(relax.TupleGetItem(lv_tuple, 0))
                lv0 = bb.emit(relax.TupleGetItem(lv_tuple, 1))
                lv_y = bb.emit(relax.TupleGetItem(lv0, 1))
                gv = bb.emit_output(bb.call_te(topi.add, lv_x, lv_y))
            bb.emit_func_output(gv)
        fused_add = bb.get().get_global_var("fused_add")

        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x, y]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(fused_add, [x, y]))
            bb.emit_func_output(gv)

        return bb.get()

    def expected():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        y = relax.Var("y", R.Tensor([10, 20], "float32"))
        with bb.function("main", [x, y]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(topi.add, x, y, primfunc_name_hint="fused_add"))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
def test_fuse_return_partial_result():
    """FuseTIR when only one output of a multi-output te op is consumed.

    ``te_argmax_idx_val`` returns both the argmax index and the max value;
    the grouped function uses only the index (via TupleGetItem 0), so the
    fused TE function must drop the unused value output.
    """

    def te_argmax_idx_val(val):
        from tvm import te

        # Reducer combining (index, value) pairs, keeping the pair whose
        # value is largest (ties resolved toward x, i.e. the earlier index).
        def f_combine(x, y):
            lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
            rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
            return lhs, rhs

        def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType):
            return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1)

        argmax = te.comm_reducer(f_combine, f_identity, name="argmax")
        m, n = val.shape
        k = te.reduce_axis((0, n), "k")
        max_idx, max_val = te.compute(
            (m,), lambda i: argmax((k.var, val[i, k]), axis=k), name="argmax"
        )
        return max_idx, max_val

    def before():
        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("fused_argmax_add", [x, offset], attrs={"Primitive": True}):
            with bb.dataflow():
                lv = bb.emit_te(te_argmax_idx_val, x)
                # Only element 0 (the index) of the two-output op is used.
                idx = bb.emit(relax.TupleGetItem(lv, 0))
                gv = bb.emit_output(bb.call_te(topi.add, idx, offset))
            bb.emit_func_output(gv)
        mod = bb.get()

        func_gv = mod.get_global_var("fused_argmax_add")
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        # Name hint fixed from "x" to "offset" for consistency with the
        # identically-typed var above (structural equality ignores names).
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("main", [x, offset]):
            with bb.dataflow():
                gv = bb.emit_output(relax.Call(func_gv, [x, offset]))
            bb.emit_func_output(gv)
        return bb.get()

    def expected():
        def fused_argmax_add(x, offset):
            # The max value output is discarded; only the index is added.
            idx, _ = te_argmax_idx_val(x)
            idx = topi.add(idx, offset)
            return idx

        bb = relax.BlockBuilder()
        x = relax.Var("x", R.Tensor([10, 20], "float32"))
        offset = relax.Var("offset", R.Tensor([10], "int32"))
        with bb.function("main", [x, offset]):
            with bb.dataflow():
                gv = bb.emit_output(bb.call_te(fused_argmax_add, x, offset))
            bb.emit_func_output(gv)
        return bb.get()

    _check(before(), expected())
+
+
# Allow running this test file directly; delegates to TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index f6d2e4c20e..6e9e14d3dc 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1073,5 +1073,4 @@ def test_class_normalize():
if __name__ == "__main__":
- test_cross_function_call()
tvm.testing.main()