This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ecdc68cc5e [Unity][Pass] Fix FuseOps error if there is no output of a
given group (#14354)
ecdc68cc5e is described below
commit ecdc68cc5e85d54e25e9cd988fdc668081bb468a
Author: Siyuan Feng <[email protected]>
AuthorDate: Wed Mar 22 11:45:27 2023 +0800
[Unity][Pass] Fix FuseOps error if there is no output of a given group
(#14354)
This PR fixes the FuseOps error that occurs when a given group has no output.
Although the `DeadCodeElimination` pass can work around the problem, it is better
to enhance the robustness of the `FuseOps` pass itself.
---
src/relax/transform/fuse_ops.cc | 42 +++++---
tests/python/relax/test_transform_fuse_ops.py | 132 ++++++++++++++++++++++++++
2 files changed, 159 insertions(+), 15 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 15bcf3513c..24f068c03f 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -449,22 +449,28 @@ class FunctionCreator : public ExprMutator {
var_remap_[var_binding->var->vid] = output_var;
outputs.Set(*output_idx, output_var);
} else {
- // Case 2. It is an internel binding, add it to the binding list.
+ // Case 2. It is an internal binding, add it to the binding list.
VisitBinding(binding);
}
}
// Step 3. Finish constructing the new block.
BindingBlock new_block = builder_->EndBlock();
- ICHECK(!outputs.empty()) << "At least one output is required.";
- Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
- body = builder_->Normalize(body);
- body = builder_->Normalize(SeqExpr({new_block}, body));
- group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
- function_ = Function(/*params=*/params_, //
- /*body=*/body, //
- /*ret_struct_info=*/NullOpt, //
- /*attrs=*/DictAttrs(group_attrs));
+ if (outputs.empty()) {
+ // If the result is not used outside
+ LOG(WARNING) << "There are dead codes in the current IRModule, please
run the "
+ "DeadCodeElimination Pass before FuseOps";
+ function_ = NullOpt;
+ } else {
+ Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs);
+ body = builder_->Normalize(body);
+ body = builder_->Normalize(SeqExpr({new_block}, body));
+ group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
+ function_ = Function(/*params=*/params_, //
+ /*body=*/body, //
+ /*ret_struct_info=*/NullOpt, //
+ /*attrs=*/DictAttrs(group_attrs));
+ }
}
/*! \brief The original bindings of the function */
@@ -476,7 +482,7 @@ class FunctionCreator : public ExprMutator {
/*! \brief The name for the fused function */
String name_hint_ = "fused";
/*! \brief The constructed Relax function */
- Function function_{nullptr};
+ Optional<Function> function_ = NullOpt;
private:
std::optional<size_t> GetOutputIndex(Var v) {
@@ -648,7 +654,7 @@ class OperatorFusor : public ExprMutator {
std::unordered_map<Group*, std::vector<Var>> pending_tuple_get;
// A grouped function which returns a tuple requires attaching
TupleGetItem to each element and
- // remapping variables in earlier bindings approriately. Thus, a binding
whose value depends on
+ // remapping variables in earlier bindings appropriately. Thus, a binding
whose value depends on
// some elements of a tuple from other group's function must be emitted
after a call to the
// tuple-producing function is emitted and remapping is done.
// To guarantee this, we process bindings in the order of the topological
sort of the group
@@ -666,10 +672,16 @@ class OperatorFusor : public ExprMutator {
ICHECK(it_creator != group2func_.end());
const FunctionCreator& func_info = it_creator->second;
+ if (!func_info.function_.defined()) {
+ // The function is not created yet, so we skip the binding.
+ continue;
+ }
+ const Function& func = func_info.function_.value();
+
// If this binding belongs to a group whose output is a tuple, the
original bound variable
// needs to be remapped to the output of TupleGetItem after the
corresponding tuple is
// emitted.
- if (IsTupleOutput(func_info.function_) &&
tuple_get_indices_.count(binding->var.get())) {
+ if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get()))
{
pending_tuple_get[group].push_back(binding->var);
}
@@ -684,7 +696,7 @@ class OperatorFusor : public ExprMutator {
"is supposed to be a variable binding";
// Step a. Add the grouped function to the IRModule
- GlobalVar gv = builder_->AddFunction(func_info.function_,
func_info.name_hint_);
+ GlobalVar gv = builder_->AddFunction(func, func_info.name_hint_);
// Step b. Create the call to the deduplicated function, and then emit
the call.
// - If this binding is an output binding, emit an output variable.
@@ -699,7 +711,7 @@ class OperatorFusor : public ExprMutator {
}
// Step c. Update the mapping used for the remapping of the binding
variables.
- if (IsTupleOutput(func_info.function_)) {
+ if (IsTupleOutput(func)) {
// If the output is a tuple, attach TupleGetItem to all tuple
elements, and
// remap variables approriately.
// The variables that need to be remapped and the corresponding tuple
indices are
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index a835eacd2c..8f7d8bf40f 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1122,5 +1122,137 @@ def test_multiple_paths():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_dead_group():
+
+ # fmt: off
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1:
R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,),
dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"),
linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10,
128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((784, 128), dtype="float32") =
R.permute_dims(linear1_weight, axes=None)
+ lv1: R.Tensor((1, 128), dtype="float32") = R.matmul(inp_0, lv,
out_dtype="float32")
+ lv2: R.Tensor((1, 128), dtype="float32") = R.add(lv1,
linear1_bias)
+ lv3: R.Tensor((1, 128), dtype="float32") = R.nn.relu(lv2)
+ lv4: R.Tensor((128, 10), dtype="float32") =
R.permute_dims(linear2_weight, axes=None)
+ lv5: R.Tensor((1, 10), dtype="float32") = R.matmul(inp_1, lv4,
out_dtype="float32")
+ lv6: R.Tensor((1, 10), dtype="float32") = R.add(lv5,
linear2_bias)
+ gv: R.Tensor((1, 10), dtype="float32") = lv6
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add:
T.Buffer((T.int64(1), T.int64(128)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(128)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1],
rxplaceholder_1[v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] +
rxplaceholder_1[v_ax1]
+
+ @T.prim_func
+ def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add:
T.Buffer((T.int64(1), T.int64(10)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(10)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1],
rxplaceholder_1[v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] +
rxplaceholder_1[v_ax1]
+
+ @T.prim_func
+ def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"),
matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ # with T.block("root"):
+ for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k,
v_i1])
+ T.writes(matmul_1[v_i0, v_i1])
+ with T.init():
+ matmul_1[v_i0, v_i1] = T.float32(0)
+ matmul_1[v_i0, v_i1] = matmul_1[v_i0, v_i1] +
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+ @T.prim_func
+ def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"),
matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ # with T.block("root"):
+ for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(128)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k,
v_i1])
+ T.writes(matmul[v_i0, v_i1])
+ with T.init():
+ matmul[v_i0, v_i1] = T.float32(0)
+ matmul[v_i0, v_i1] = matmul[v_i0, v_i1] +
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+ @T.prim_func
+ def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)),
"float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ # with T.block("root"):
+ for i0, i1 in T.grid(T.int64(1), T.int64(128)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(rxplaceholder[v_i0, v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.max(rxplaceholder[v_i0, v_i1],
T.float32(0))
+
+ @T.prim_func
+ def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)),
"float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(784), T.int64(128)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax1, v_ax0])
+ T.writes(T_transpose[v_ax0, v_ax1])
+ T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+ @T.prim_func
+ def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)),
"float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(128), T.int64(10)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax1, v_ax0])
+ T.writes(T_transpose[v_ax0, v_ax1])
+ T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+ @R.function
+ def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"),
lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,),
dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ cls = Expected
+ with R.dataflow():
+ lv5 = R.call_tir(cls.matmul1, (inp_1, lv4),
out_sinfo=R.Tensor((1, 10), dtype="float32"))
+ gv = R.call_tir(cls.add1, (lv5, linear2_bias),
out_sinfo=R.Tensor((1, 10), dtype="float32"))
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1:
R.Tensor((1, 128), dtype="float32"), linear1_bias: R.Tensor((128,),
dtype="float32"), linear1_weight: R.Tensor((128, 784), dtype="float32"),
linear2_bias: R.Tensor((10,), dtype="float32"), linear2_weight: R.Tensor((10,
128), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ cls = Expected
+ with R.dataflow():
+ lv = R.call_tir(cls.transpose, (linear1_weight,),
out_sinfo=R.Tensor((784, 128), dtype="float32"))
+ lv4 = R.call_tir(cls.transpose1, (linear2_weight,),
out_sinfo=R.Tensor((128, 10), dtype="float32"))
+ lv_1: R.Tensor((1, 10), dtype="float32") =
cls.fused_matmul1_add1(inp_1, lv4, linear2_bias)
+ gv: R.Tensor((1, 10), dtype="float32") = lv_1
+ R.output(gv)
+ return gv
+
+ # fmt: on
+
+ mod = relax.transform.LegalizeOps()(Module)
+ _check(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()