This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ecdc68cc5e [Unity][Pass] Fix FuseOps error if there is no output of a 
given group (#14354)
ecdc68cc5e is described below

commit ecdc68cc5e85d54e25e9cd988fdc668081bb468a
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 22 11:45:27 2023 +0800

    [Unity][Pass] Fix FuseOps error if there is no output of a given group 
(#14354)
    
    This PR fixes the FuseOps error that occurs when a given group has no output.
    Although the `DeadCodeElimination` pass can solve the problem, it is better
    to enhance the robustness of the `FuseOps` pass itself.
---
 src/relax/transform/fuse_ops.cc               |  42 +++++---
 tests/python/relax/test_transform_fuse_ops.py | 132 ++++++++++++++++++++++++++
 2 files changed, 159 insertions(+), 15 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 15bcf3513c..24f068c03f 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -449,22 +449,28 @@ class FunctionCreator : public ExprMutator {
         var_remap_[var_binding->var->vid] = output_var;
         outputs.Set(*output_idx, output_var);
       } else {
-        // Case 2. It is an internel binding, add it to the binding list.
+        // Case 2. It is an internal binding, add it to the binding list.
         VisitBinding(binding);
       }
     }
 
     // Step 3. Finish constructing the new block.
     BindingBlock new_block = builder_->EndBlock();
-    ICHECK(!outputs.empty()) << "At least one output is required.";
-    Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
-    body = builder_->Normalize(body);
-    body = builder_->Normalize(SeqExpr({new_block}, body));
-    group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
-    function_ = Function(/*params=*/params_,           //
-                         /*body=*/body,                //
-                         /*ret_struct_info=*/NullOpt,  //
-                         /*attrs=*/DictAttrs(group_attrs));
+    if (outputs.empty()) {
+      // If the result is not used outside
+      LOG(WARNING) << "There are dead codes in the current IRModule, please 
run the "
+                      "DeadCodeElimination Pass before FuseOps";
+      function_ = NullOpt;
+    } else {
+      Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+      body = builder_->Normalize(body);
+      body = builder_->Normalize(SeqExpr({new_block}, body));
+      group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+      function_ = Function(/*params=*/params_,           //
+                           /*body=*/body,                //
+                           /*ret_struct_info=*/NullOpt,  //
+                           /*attrs=*/DictAttrs(group_attrs));
+    }
   }
 
   /*! \brief The original bindings of the function */
@@ -476,7 +482,7 @@ class FunctionCreator : public ExprMutator {
   /*! \brief The name for the fused function */
   String name_hint_ = "fused";
   /*! \brief The constructed Relax function */
-  Function function_{nullptr};
+  Optional<Function> function_ = NullOpt;
 
  private:
   std::optional<size_t> GetOutputIndex(Var v) {
@@ -648,7 +654,7 @@ class OperatorFusor : public ExprMutator {
     std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;
 
     // A grouped function which returns a tuple requires attaching 
TupleGetItem to each element and
-    // remapping variables in earlier bindings approriately. Thus, a binding 
whose value depends on
+    // remapping variables in earlier bindings appropriately. Thus, a binding 
whose value depends on
     // some elements of a tuple from other group's function must be emitted 
after a call to the
     // tuple-producing function is emitted and remapping is done.
     // To guarantee this, we process bindings in the order of the topological 
sort of the group
@@ -666,10 +672,16 @@ class OperatorFusor : public ExprMutator {
       ICHECK(it_creator != group2func_.end());
       const FunctionCreator& func_info = it_creator->second;
 
+      if (!func_info.function_.defined()) {
+        // The function is not created yet, so we skip the binding.
+        continue;
+      }
+      const Function& func = func_info.function_.value();
+
       // If this binding belongs to a group whose output is a tuple, the 
original bound variable
       // needs to be remapped to the output of TupleGetItem after the 
corresponding tuple is
       // emitted.
-      if (IsTupleOutput(func_info.function_) && 
tuple_get_indices_.count(binding->var.get())) {
+      if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get())) 
{
         pending_tuple_get[group].push_back(binding->var);
       }
 
@@ -684,7 +696,7 @@ class OperatorFusor : public ExprMutator {
                                         "is supposed to be a variable binding";
 
       // Step a. Add the grouped function to the IRModule
-      GlobalVar gv = builder_->AddFunction(func_info.function_, 
func_info.name_hint_);
+      GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_);
 
       // Step b. Create the call to the deduplicated function, and then emit 
the call.
       //  - If this binding is an output binding, emit an output variable.
@@ -699,7 +711,7 @@ class OperatorFusor : public ExprMutator {
       }
 
       // Step c. Update the mapping used for the remapping of the binding 
variables.
-      if (IsTupleOutput(func_info.function_)) {
+      if (IsTupleOutput(func)) {
         // If the output is a tuple, attach TupleGetItem to all tuple 
elements, and
         // remap variables approriately.
         // The variables that need to be remapped and the corresponding tuple 
indices are
diff --git a/tests/python/relax/test_transform_fuse_ops.py 
b/tests/python/relax/test_transform_fuse_ops.py
index a835eacd2c..8f7d8bf40f 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1122,5 +1122,137 @@ def test_multiple_paths():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dead_group():
+
+    # fmt: off
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: 
R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,), 
dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"), 
linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10, 
128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((784, 128), dtype="float32") = 
R.permute_dims(linear1_weight, axes=None)
+                lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(inp_0, lv, 
out_dtype="float32")
+                lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1, 
linear1_bias)
+                lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
+                lv4: R.Tensor((128, 10), dtype="float32") = 
R.permute_dims(linear2_weight, axes=None)
+                lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(inp_1, lv4, 
out_dtype="float32")
+                lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5, 
linear2_bias)
+                gv: R.Tensor((1, 10), dtype="float32") = lv6
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: 
T.Buffer((T.int64(1), T.int64(128)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1], 
rxplaceholder_1[v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + 
rxplaceholder_1[v_ax1]
+
+        @T.prim_func
+        def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: 
T.Buffer((T.int64(1), T.int64(10)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1], 
rxplaceholder_1[v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + 
rxplaceholder_1[v_ax1]
+
+        @T.prim_func
+        def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), 
matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": True})
+            # with T.block("root"):
+            for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, 
v_i1])
+                    T.writes(matmul_1[v_i0, v_i1])
+                    with T.init():
+                        matmul_1[v_i0, v_i1] = T.float32(0)
+                    matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] + 
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+        @T.prim_func
+        def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), 
matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": True})
+            # with T.block("root"):
+            for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, 
v_i1])
+                    T.writes(matmul[v_i0, v_i1])
+                    with T.init():
+                        matmul[v_i0, v_i1] = T.float32(0)
+                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + 
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+        @T.prim_func
+        def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), 
"float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            # with T.block("root"):
+            for i0, i1 in T.grid(T.int64(1), T.int64(128)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(rxplaceholder[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1], 
T.float32(0))
+
+        @T.prim_func
+        def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), 
"float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
+            T.func_attr({"op_pattern": 2, "tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax1, v_ax0])
+                    T.writes(T_transpose[v_ax0, v_ax1])
+                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+        @T.prim_func
+        def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), 
"float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
+            T.func_attr({"op_pattern": 2, "tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax1, v_ax0])
+                    T.writes(T_transpose[v_ax0, v_ax1])
+                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+        @R.function
+        def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), 
lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), 
dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Expected
+            with R.dataflow():
+                lv5 = R.call_tir(cls.matmul1, (inp_1, lv4), 
out_sinfo=R.Tensor((1, 10), dtype="float32"))
+                gv = R.call_tir(cls.add1, (lv5, linear2_bias), 
out_sinfo=R.Tensor((1, 10), dtype="float32"))
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: 
R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,), 
dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"), 
linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10, 
128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = Expected
+            with R.dataflow():
+                lv = R.call_tir(cls.transpose, (linear1_weight,), 
out_sinfo=R.Tensor((784, 128), dtype="float32"))
+                lv4 = R.call_tir(cls.transpose1, (linear2_weight,), 
out_sinfo=R.Tensor((128, 10), dtype="float32"))
+                lv_1: R.Tensor((1, 10), dtype="float32") = 
cls.fused_matmul1_add1(inp_1, lv4, linear2_bias)
+                gv: R.Tensor((1, 10), dtype="float32") = lv_1
+                R.output(gv)
+            return gv
+
+    # fmt: on
+
+    mod = relax.transform.LegalizeOps()(Module)
+    _check(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to