This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 73a45e3022e1e77b8afe200196a6747c0010e8bd
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Feb 18 22:19:54 2023 -0500

    [Unity][Pass] FuseOps FuseTIR fixes (#14044)
    
    This PR fixes two bugs in FuseOps and FuseTIR:
    
    First, FuseOps previously rewrote only the "main" function of the
    IRModule. After the fix, FuseOps goes through every non-primitive
    Relax function. Test cases are added on both the FuseOps and FuseTIR
    sides to ensure that both passes work on modules with multiple Relax
    functions.
    
    Second, FuseOps and FuseTIR did not take "call_dps_packed"-style
    "call_tir" into account. Previously, both passes directly downcast the
    first argument of "call_tir" to GlobalVar, which is wrong when the
    "call_tir" is in "call_dps_packed" style and the first argument is a
    PackedFunc (an ExternFunc) rather than a GlobalVar. With this fix,
    FuseOps and FuseTIR skip such "call_tir"s. Tests for both FuseOps and
    FuseTIR are added accordingly.
---
 src/relax/transform/fuse_ops.cc               |  54 +++++-----
 src/relax/transform/fuse_tir.cc               |  15 +--
 tests/python/relax/test_transform_fuse_ops.py |  81 ++++++++++++++-
 tests/python/relax/test_transform_fuse_tir.py | 141 +++++++++++++++++++++++++-
 4 files changed, 252 insertions(+), 39 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index f3559b72da..0a0209bb87 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -100,11 +100,15 @@ class GraphCreator : public ExprVisitor {
    * \return The created IndexedForwardGraph
    */
   static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
-    // Since cross-function call is not supported yet, FuseOps only serves the 
entry function, whose
-    // name is "main".
-    auto relax_func = Downcast<Function>(mod->Lookup("main"));
     GraphCreator creator(mod, arena);
-    creator(relax_func);
+    for (const auto& it : mod->functions) {
+      // Only visit Relax function without attr kPrimitive.
+      const auto* func = it.second.as<FunctionNode>();
+      if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
+        continue;
+      }
+      creator(GetRef<Function>(func));
+    }
 
     // The algorithm of the graph creator ensures that each created node will 
be added to the
     // post-dfs order and will be set its op pattern. Thus we check whether 
all these containers
@@ -178,25 +182,26 @@ class GraphCreator : public ExprVisitor {
     // recurse into the call expression.
     const auto* op = call->op.as<OpNode>();
     if (op == call_tir_op_.get()) {
-      const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
-      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
+      // Skip ExternFunc for call_dps_packed.
+      if (const auto* global_var = call->args[0].as<GlobalVarNode>()) {
+        tir::PrimFunc func = 
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(global_var)));
 
-      // Override args for call_tir
-      args = Downcast<Tuple>(call->args[1])->fields;
+        // Override args for call_tir
+        args = Downcast<Tuple>(call->args[1])->fields;
 
-      // TODO(tvm-team): handle the shape argument (args[3])
-      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
-      if (opt_pattern.defined()) {
-        pattern = 
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
-      } else {
-        pattern = OpPatternKind::kOpaque;
+        Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+        if (opt_pattern.defined()) {
+          pattern = 
static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+        } else {
+          pattern = OpPatternKind::kOpaque;
+        }
       }
     }
     // The pattern of the current binding variable node is set to the pattern 
of this operator.
     SetNodePattern(binding_var_node, pattern);
     // Visit all call args
     for (const Expr& arg : args) {
-      ICHECK(IsLeaf(arg));
+      ICHECK(IsLeafOrTuple(arg));
       VisitLeaf(arg, binding_var_node, pattern);
     }
   }
@@ -226,6 +231,10 @@ class GraphCreator : public ExprVisitor {
   void VisitLeaf(const Expr& leaf_expr, IndexedForwardGraph::Node* 
binding_var_node,
                  const OpPatternKind& pattern) {
     ICHECK_NOTNULL(binding_var_node);
+    if (!leaf_expr->IsInstance<LeafExprNode>()) {
+      // Skip GlobalVar, ExternFunc, OpNode.
+      return;
+    }
 
     // Recursive visit if it's Tuple
     if (const auto* tuple = leaf_expr.as<TupleNode>()) {
@@ -253,21 +262,6 @@ class GraphCreator : public ExprVisitor {
 
   /********** Helper Functions **********/
 
-  /*!
-   * \brief Check whether the expression is a leaf expression
-   * \param expr The expression to be checked
-   * \return Whether the expression is a leaf expression
-   * \note In order to avoid too much refactor, this method is a simple 
copy-paste of the is-leaf
-   * check in "block_builder.cc". And it should be refactored in the future.
-   * \sa src/relax/ir/block_builder.cc
-   */
-  static bool IsLeaf(const Expr& expr) {
-    // NOTE: Tuples are treated as leaf nodes for ergonomics
-    return expr.as<VarNode>() || expr.as<GlobalVarNode>() || 
expr.as<ConstantNode>() ||
-           expr.as<ShapeExprNode>() || expr.as<ExternFuncNode>() || 
expr.as<OpNode>() ||
-           expr.as<TupleNode>();
-  }
-
   /*!
    * \brief Create a graph node corresponding to the input key
    * \param key The object which is used to create the graph node
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index fa5c296d27..925f09d85d 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -670,14 +670,15 @@ class TIRFuseMutator : public ExprMutator {
       }
     } else if (call->op == call_tir_op_) {
       // Case 2. It is a call_tir, re-emit the PrimFunc.
-      GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
-      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
-      GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
-      return Call(call->op, {new_gv, call->args[1]}, call->attrs, 
call->sinfo_args, call->span);
-    } else {
-      // Case 3. CallNode in other types. Leave it as it is.
-      return call;
+      if (const auto* gv = call->args[0].as<GlobalVarNode>()) {
+        tir::PrimFunc func = 
Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(gv)));
+        GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint);
+        return Call(call->op, {new_gv, call->args[1]}, call->attrs, 
call->sinfo_args, call->span);
+      }
     }
+
+    // Case 3. CallNode in other types. Leave it as it is.
+    return call;
   }
 
   /********** Helper Functions **********/
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index 1a228bb268..6fad4f8165 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax, topi
-from tvm.script import relax as R
+from tvm.script import ir as I, relax as R
 
 
 def _check(mod_actual, mod_expected):
@@ -755,5 +755,84 @@ def test_softmax():
     _check(before(), expected())
 
 
+def test_multiple_relax_functions():
+    def before():
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("func1", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+
+        x = relax.Var("x", R.Tensor([20, 10], "float32"))
+        with bb.function("func2", [x]):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, relax.const(1, "float32"))
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    def expected():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+        fused_add_exp_squeeze = 
bb.get().get_global_var("fused_add_exp_squeeze")
+
+        x = relax.Var("x", R.Tensor([20, 10], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+        fused_add1_exp1_squeeze1 = 
bb.get().get_global_var("fused_add1_exp1_squeeze1")
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("func1", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(
+                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, 
"float32")])
+                )
+            bb.emit_func_output(gv)
+
+        x = relax.Var("x", R.Tensor([20, 10], "float32"))
+        with bb.function("func2", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(
+                    relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, 
"float32")])
+                )
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    _check(before(), expected())
+
+
+def test_skip_call_dps_packed():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            with R.dataflow():
+                y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
+                R.output(y)
+            return y
+
+    # FuseOps should does no change to it.
+    _check(Module, Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_tir.py 
b/tests/python/relax/test_transform_fuse_tir.py
index 91edab2bbb..c2784edec7 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -18,7 +18,7 @@
 import tvm
 import tvm.testing
 from tvm import relax, topi
-from tvm.script import relax as R
+from tvm.script import ir as I, relax as R, tir as T
 
 
 def _check(mod_before, mod_expected):
@@ -559,5 +559,144 @@ def test_fuse_return_partial_result():
     _check(before(), expected())
 
 
+def test_multiple_relax_functions():
+    def before():
+        bb = relax.BlockBuilder()
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 
1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+        fused_add_exp_squeeze = 
bb.get().get_global_var("fused_add_exp_squeeze")
+
+        x = relax.Var("x", R.Tensor([20, 10], "float32"))
+        p0 = relax.Var("p0", R.Tensor((), "float32"))
+        with bb.function("fused_add1_exp1_squeeze1", [x, p0], 
attrs={"Primitive": 1}):
+            with bb.dataflow():
+                lv0 = bb.emit_te(topi.add, x, p0)
+                lv1 = bb.emit_te(topi.exp, lv0)
+                gv = bb.emit_output(bb.call_te(topi.squeeze, lv1))
+            bb.emit_func_output(gv)
+        fused_add1_exp1_squeeze1 = 
bb.get().get_global_var("fused_add1_exp1_squeeze1")
+
+        x = relax.Var("x", R.Tensor([10, 20], "float32"))
+        with bb.function("func1", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(
+                    relax.Call(fused_add_exp_squeeze, [x, relax.const(1, 
"float32")])
+                )
+            bb.emit_func_output(gv)
+
+        x = relax.Var("x", R.Tensor([20, 10], "float32"))
+        with bb.function("func2", [x]):
+            with bb.dataflow():
+                gv = bb.emit_output(
+                    relax.Call(fused_add1_exp1_squeeze1, [x, relax.const(1, 
"float32")])
+                )
+            bb.emit_func_output(gv)
+
+        return bb.get()
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 
20), dtype="float32"):
+            with R.dataflow():
+                gv2 = R.call_tir(
+                    fused_add_exp_squeeze,
+                    (x, R.const(1, "float32")),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv2)
+            return gv2
+
+        @R.function
+        def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 
10), dtype="float32"):
+            with R.dataflow():
+                gv3 = R.call_tir(
+                    fused_add1_exp1_squeeze1,
+                    (x, R.const(1, "float32")),
+                    out_sinfo=R.Tensor((20, 10), dtype="float32"),
+                )
+                R.output(gv3)
+            return gv3
+
+        @T.prim_func
+        def fused_add1_exp1_squeeze1(
+            x: T.Buffer((T.int64(20), T.int64(10)), "float32"),
+            p0: T.Buffer((), "float32"),
+            T_squeeze: T.Buffer((T.int64(20), T.int64(10)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            T_add = T.alloc_buffer((T.int64(20), T.int64(10)))
+            compute = T.alloc_buffer((T.int64(20), T.int64(10)))
+            for ax0, ax1 in T.grid(T.int64(20), T.int64(10)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(20), T.int64(10)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_add[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(20), T.int64(10)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_squeeze[v_ax0, v_ax1])
+                    T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1]
+
+        @T.prim_func
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            p0: T.Buffer((), "float32"),
+            T_squeeze: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            T_add = T.alloc_buffer((T.int64(10), T.int64(20)))
+            compute = T.alloc_buffer((T.int64(10), T.int64(20)))
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_add[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.exp(T_add[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_squeeze[v_ax0, v_ax1])
+                    T_squeeze[v_ax0, v_ax1] = compute[v_ax0, v_ax1]
+
+    _check(before(), Expected)
+
+
+def test_skip_call_dps_packed():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")):
+            with R.dataflow():
+                y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), 
"float32"))
+                R.output(y)
+            return y
+
+    # FuseTIR should does no change to it.
+    _check(Module, Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to