This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8ea976276c [Unity] Properly handle tuple-outputting function in
`FuseOpsByPattern` (#14525)
8ea976276c is described below
commit 8ea976276cde885359280ec1a7d7d78d4b2c06bb
Author: masahi <[email protected]>
AuthorDate: Fri Apr 7 21:25:18 2023 +0900
[Unity] Properly handle tuple-outputting function in `FuseOpsByPattern`
(#14525)
* correctly handle tuple output function
* properly handle nested tuple case
* typo
* fix MergeCompositeFunctions test
---
src/relax/transform/fuse_ops.cc | 28 ++++++--
.../relax/test_transform_fuse_ops_by_pattern.py | 79 +++++++++++++++++++++-
2 files changed, 102 insertions(+), 5 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 8e4346e206..cf9bd0ac37 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -659,6 +659,16 @@ class OperatorFusor : public ExprMutator {
return sinfo->ret->IsInstance<TupleStructInfoNode>();
}
+ bool IsNestedTupleOutput(Function f) {
+ if (!IsTupleOutput(f)) return false;
+
+ auto tup =
GetStructInfo(f).as<FuncStructInfoNode>()->ret.as<TupleStructInfoNode>();
+ for (const auto& field : tup->fields) {
+ if (field->IsInstance<TupleStructInfoNode>()) return true;
+ }
+ return false;
+ }
+
BindingBlock VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
return VisitBindingBlock_(df_block);
@@ -722,7 +732,12 @@ class OperatorFusor : public ExprMutator {
// needs to be remapped to the output of TupleGetItem after the
corresponding tuple is
// emitted.
if (IsTupleOutput(func) && tuple_get_indices_.count(binding->var.get()))
{
- pending_tuple_get[group].push_back(binding->var);
+ if (!GetStructInfo(binding->var)->IsInstance<TupleStructInfoNode>() ||
+ IsNestedTupleOutput(func)) {
+ // When binding->var itself is a tuple, we do not need to remap this
variable to the
+ // output of TupleGetItem unless the output is a nested tuple.
+ pending_tuple_get[group].push_back(binding->var);
+ }
}
// Case 2. If the binding is not the last binding of the group, we skip
it.
@@ -751,7 +766,7 @@ class OperatorFusor : public ExprMutator {
}
// Step c. Update the mapping used for the remapping of the binding
variables.
- if (IsTupleOutput(func)) {
+ if (IsTupleOutput(func) && !pending_tuple_get.empty()) {
// If the output is a tuple, attach TupleGetItem to all tuple
elements, and
// remap variables appropriately.
// The variables that need to be remapped and the corresponding tuple
indices are
@@ -1018,8 +1033,13 @@ class PatternBasedPartitioner : ExprVisitor {
ICHECK(parent_group);
parent_group->attrs.Set(attr::kComposite, pat_name_);
for (const auto& [pat, match] : matches_opt.value()) {
- // Put all matching call nodes into the parent group.
- if (pat->IsInstance<CallPatternNode>() && match != GetRef<Call>(call))
{
+ // Put all matching expressions into the parent group. But we need to
be careful not to
+ // merge expressions matched by a wildcard pattern, since a wildcard
can match an output of
+ // the previous group. For example, when there are two back-to-back
conv2d ops, the output
+ // of the first conv2d is matched to the input of the second conv2d
via a wildcard pattern.
+ // But we must avoid merging the first conv2d into the group of the
second conv2d.
+ if ((pat->IsInstance<CallPatternNode>() && match !=
GetRef<Call>(call)) ||
+ pat->IsInstance<TupleGetItemPatternNode>()) {
// Put the bound variable on the LHS into the same parent group.
AddToGroup(value_to_bound_var_[match], parent_group);
}
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 2f3e2d479f..146c8e1ebc 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -19,7 +19,12 @@ import pytest
import tvm
from tvm import relax
-from tvm.relax.dpl.pattern import is_op, make_fused_bias_activation_pattern,
wildcard
+from tvm.relax.dpl.pattern import (
+ is_op,
+ make_fused_bias_activation_pattern,
+ wildcard,
+ is_tuple_get_item,
+)
from tvm.relax.transform import PatternCheckContext
from tvm.script import ir as I
from tvm.script import relax as R
@@ -671,5 +676,77 @@ def test_bind_constants():
)
+def test_split():
+ @R.function
+ def func(inp: R.Tensor((16, 32), "float32")):
+ with R.dataflow():
+ tup = R.split(inp, [16], axis=1)
+ out = R.add(tup[0], tup[1])
+ R.output(out)
+ return out
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def fused_relax_split(
+ inp: R.Tensor((16, 32), dtype="float32")
+ ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")):
+ R.func_attr({"Composite": "x.split", "Primitive": 1})
+ with R.dataflow():
+ gv: R.Tuple(
+ R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")
+ ) = R.split(inp, indices_or_sections=[16], axis=1)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ cls = Expected1
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")
+ ) = cls.fused_relax_split(inp)
+ lv1: R.Tensor((16, 16), dtype="float32") = lv[0]
+ lv2: R.Tensor((16, 16), dtype="float32") = lv[1]
+ out: R.Tensor((16, 16), dtype="float32") = R.add(lv1, lv2)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def fused_relax_split_relax_add(
+ inp: R.Tensor((16, 32), dtype="float32")
+ ) -> R.Tensor((16, 16), dtype="float32"):
+ R.func_attr({"Composite": "x.split", "Primitive": 1})
+ with R.dataflow():
+ tup: R.Tuple(
+ R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16),
dtype="float32")
+ ) = R.split(inp, indices_or_sections=[16], axis=1)
+ lv1: R.Tensor((16, 16), dtype="float32") = tup[0]
+ lv2: R.Tensor((16, 16), dtype="float32") = tup[1]
+ gv: R.Tensor((16, 16), dtype="float32") = R.add(lv1, lv2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ cls = Expected2
+ with R.dataflow():
+ gv: R.Tensor((16, 16), dtype="float32") =
cls.fused_relax_split_relax_add(inp)
+ R.output(gv)
+ return gv
+
+ mod = tvm.IRModule({"main": func})
+
+ split = is_op("relax.split")(wildcard())
+ it1 = is_tuple_get_item(split, 0)
+ it2 = is_tuple_get_item(split, 1)
+ add = is_op("relax.add")(it1, it2)
+
+ check(mod, [("x.split", split)], Expected1)
+ check(mod, [("x.split", add)], Expected2)
+
+
if __name__ == "__main__":
pytest.main([__file__])