This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0ce19d120c [Unity] Improve implementation of FuseOps (#14229)
0ce19d120c is described below
commit 0ce19d120cca67c06a426238e973fbe7157503a0
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 7 19:57:23 2023 -0500
[Unity] Improve implementation of FuseOps (#14229)
This PR improves the implementation of fuse-ops so that it is more deterministic.
Co-authored-by: Ruihang Lai <[email protected]>
---
src/relax/transform/fuse_ops.cc | 18 ++-
tests/python/relax/test_transform_fuse_ops.py | 159 ++++++++++++++++++++++++++
2 files changed, 171 insertions(+), 6 deletions(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index d6013c8874..3b6b3c17ac 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -755,10 +755,12 @@ class OperatorFusor : public ExprMutator {
// Only check those group defined before.
// Skip the vars from input or groups with single binding.
if (producer_group != cur_group) {
- ICHECK(!group_deps_[producer_group].count(cur_group))
- << "A cyclic dependency detected between the groups " <<
binding->var->name_hint()
- << " and " << used_var->name_hint() << " are in.";
- group_deps_[cur_group].insert(producer_group);
+ for (Group* depgroup : group_deps_[producer_group]) {
+ ICHECK(depgroup != cur_group)
+ << "A cyclic dependency detected between the groups " <<
binding->var->name_hint()
+ << " and " << used_var->name_hint() << " are in.";
+ }
+ group_deps_[cur_group].push_back(producer_group);
}
if (auto producer = group2func_.find(producer_group);
@@ -865,8 +867,12 @@ class OperatorFusor : public ExprMutator {
/*! \brief Record the index for TupleGetItem if the variable needs to be
remapped to an output
* tuple element after fusion. */
std::unordered_map<const VarNode*, int> tuple_get_indices_;
- /*! \brief A map from a group to its dependent groups, used to detect cyclic
dependencies. */
- std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
+ /*!
+ * \brief A map from a group to its dependent groups, used to detect cyclic
dependencies.
+ * \note Use vector so we can be deterministic, there won't be a lot of dep
groups so
+ * linear search is OK.
+ */
+ std::unordered_map<Group*, std::vector<Group*>> group_deps_;
/*! \brief Whether or not to lift bound constants to parameters of the
grouped function. */
bool lift_constants_{true};
};
diff --git a/tests/python/relax/test_transform_fuse_ops.py
b/tests/python/relax/test_transform_fuse_ops.py
index a7a6066c4b..33d57417cf 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -956,5 +956,164 @@ def test_layer_norm_silu():
_check(Module, Expected)
+def test_multiple_paths():
+ # fmt: off
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"),
+ inp_1: R.Tensor((2, 1280), dtype="float32"),
+ w1: R.Tensor((320, 320, 3, 3), dtype="float32"),
+ b1: R.Tensor((320,), "float32"),
+ w2: R.Tensor((320, 1280), "float32"),
+ b2: R.Tensor((320,), "float32"),
+ ):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv27: R.Tensor((2, 320, 64, 64), dtype="float32") =
R.nn.conv2d(inp_0, w1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1],
groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW",
out_dtype="float32")
+ lv28: R.Tensor((1, 320, 1, 1), dtype="float32") =
R.reshape(b1, R.shape([1, 320, 1, 1])) ##
+ lv29: R.Tensor((2, 320, 64, 64), dtype="float32") =
R.add(lv27, lv28)
+ lv31: R.Tensor((1280, 320), dtype="float32") =
R.permute_dims(w2, axes=None) ##
+ lv32: R.Tensor((2, 320), dtype="float32") = R.matmul(inp_1,
lv31, out_dtype="float32")
+ lv33: R.Tensor((2, 320), dtype="float32") = R.add(lv32, b2)
+ lv35: R.Tensor((2, 320, 1, 1), dtype="float32") =
R.reshape(lv33, R.shape([2, 320, 1, 1]))
+ lv36: R.Tensor((2, 320, 64, 64), dtype="float32") =
R.add(lv29, lv35)
+ gv = lv36
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320),
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1),
T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2),
T.int64(320), T.int64(64), T.int64(64)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320),
T.int64(64), T.int64(64)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0,
v_ax1, v_ax2, v_ax3] + rxplaceholder_1[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+
+ @T.prim_func
+ def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(320),), "float32"), T_add:
T.Buffer((T.int64(2), T.int64(320)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(320)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1],
rxplaceholder_1[v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] +
rxplaceholder_1[v_ax1]
+
+ @T.prim_func
+ def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320),
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2),
T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2),
T.int64(320), T.int64(64), T.int64(64)), "float32")):
+ T.func_attr({"op_pattern": 0, "tir.noalias": True})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320),
T.int64(64), T.int64(64)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0,
v_ax1, v_ax2, v_ax3] + rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)]
+
+ @T.prim_func
+ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320),
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320),
T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw:
T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ pad_temp = T.alloc_buffer((T.int64(2), T.int64(320), T.int64(66),
T.int64(66)))
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320),
T.int64(66), T.int64(66)):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3
- T.int64(1)])
+ T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
+ pad_temp[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(65) and T.int64(1) <= v_i3
and v_i3 < T.int64(65), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 -
T.int64(1)], T.float32(0))
+ for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(2), T.int64(320),
T.int64(64), T.int64(64), T.int64(320), T.int64(3), T.int64(3)):
+ with T.block("conv2d_nchw"):
+ v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx =
T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
+ T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx],
rxplaceholder_1[v_ff, v_rc, v_ry, v_rx])
+ T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
+ with T.init():
+ conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
+ conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn,
v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] *
rxplaceholder_1[v_ff, v_rc, v_ry, v_rx]
+
+ @T.prim_func
+ def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)),
"float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)),
"float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")):
+ T.func_attr({"op_pattern": 4, "tir.noalias": True})
+ for i0, i1, k in T.grid(T.int64(2), T.int64(320), T.int64(1280)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k,
v_i1])
+ T.writes(matmul[v_i0, v_i1])
+ with T.init():
+ matmul[v_i0, v_i1] = T.float32(0)
+ matmul[v_i0, v_i1] = matmul[v_i0, v_i1] +
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+ @T.prim_func
+ def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"),
T_reshape: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)),
"float32")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": True})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(320),
T.int64(1), T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[(v_ax1 + v_ax2 + v_ax3) %
T.int64(320)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]
+
+ @T.prim_func
+ def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)),
"float32"), T_reshape: T.Buffer((T.int64(2), T.int64(320), T.int64(1),
T.int64(1)), "float32")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": True})
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320),
T.int64(1), T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[((v_ax1 + v_ax2 + v_ax3) //
T.int64(320) + v_ax0) % T.int64(2), (v_ax1 + v_ax2 + v_ax3) % T.int64(320)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // T.int64(320) + v_ax0) % T.int64(2),
(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]
+
+ @T.prim_func
+ def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)),
"float32"), T_transpose: T.Buffer((T.int64(1280), T.int64(320)), "float32")):
+ T.func_attr({"op_pattern": 2, "tir.noalias": True})
+ for ax0, ax1 in T.grid(T.int64(1280), T.int64(320)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax1, v_ax0])
+ T.writes(T_transpose[v_ax0, v_ax1])
+ T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+ @R.function
+ def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64),
dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28:
R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1),
dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv27 = R.call_tir(conv2d, (inp_0, w1), out_sinfo=R.Tensor((2,
320, 64, 64), dtype="float32"))
+ lv29 = R.call_tir(add, (lv27, lv28), out_sinfo=R.Tensor((2,
320, 64, 64), dtype="float32"))
+ gv = R.call_tir(add2, (lv29, lv35), out_sinfo=R.Tensor((2,
320, 64, 64), dtype="float32"))
+ R.output(gv)
+ return gv
+
+ @R.function
+ def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"),
lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,),
dtype="float32")) -> R.Tensor((2, 320), dtype="float32"):
+ R.func_attr({"Primitive": 1})
+ with R.dataflow():
+ lv32 = R.call_tir(matmul, (inp_1, lv31),
out_sinfo=R.Tensor((2, 320), dtype="float32"))
+ gv = R.call_tir(add1, (lv32, b2), out_sinfo=R.Tensor((2, 320),
dtype="float32"))
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1:
R.Tensor((2, 1280), dtype="float32"), w1: R.Tensor((320, 320, 3, 3),
dtype="float32"), b1: R.Tensor((320,), dtype="float32"), w2: R.Tensor((320,
1280), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2,
320, 64, 64), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv28 = R.call_tir(reshape, (b1,), out_sinfo=R.Tensor((1, 320,
1, 1), dtype="float32"))
+ lv31 = R.call_tir(transpose, (w2,), out_sinfo=R.Tensor((1280,
320), dtype="float32"))
+ lv: R.Tensor((2, 320), dtype="float32") =
fused_matmul_add1(inp_1, lv31, b2)
+ lv35 = R.call_tir(reshape1, (lv,), out_sinfo=R.Tensor((2, 320,
1, 1), dtype="float32"))
+ lv1: R.Tensor((2, 320, 64, 64), dtype="float32") =
fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
+ gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+ # fmt: on
+
+ mod = relax.transform.LegalizeOps()(Module)
+ mod = relax.transform.AnnotateTIROpPattern()(mod)
+ mod = relax.transform.FuseOps()(mod)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()