This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0ce19d120c [Unity] Improve implementation of FuseOps (#14229)
0ce19d120c is described below

commit 0ce19d120cca67c06a426238e973fbe7157503a0
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Mar 7 19:57:23 2023 -0500

    [Unity] Improve implementation of FuseOps (#14229)
    
    This PR improves implementation of fuse-ops so it is more deterministic.
    
    Co-authored-by: Ruihang Lai <[email protected]>
---
 src/relax/transform/fuse_ops.cc               |  18 ++-
 tests/python/relax/test_transform_fuse_ops.py | 159 ++++++++++++++++++++++++++
 2 files changed, 171 insertions(+), 6 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index d6013c8874..3b6b3c17ac 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -755,10 +755,12 @@ class OperatorFusor : public ExprMutator {
           // Only check those group defined before.
           // Skip the vars from input or groups with single binding.
           if (producer_group != cur_group) {
-            ICHECK(!group_deps_[producer_group].count(cur_group))
-                << "A cyclic dependency detected between the groups " << binding->var->name_hint()
-                << " and " << used_var->name_hint() << " are in.";
-            group_deps_[cur_group].insert(producer_group);
+            for (Group* depgroup : group_deps_[producer_group]) {
+              ICHECK(depgroup != cur_group)
+                  << "A cyclic dependency detected between the groups " << binding->var->name_hint()
+                  << " and " << used_var->name_hint() << " are in.";
+            }
+            group_deps_[cur_group].push_back(producer_group);
           }
 
           if (auto producer = group2func_.find(producer_group);
@@ -865,8 +867,12 @@ class OperatorFusor : public ExprMutator {
   /*! \brief Record the index for TupleGetItem if the variable needs to be remapped to an output
    * tuple element after fusion. */
   std::unordered_map<const VarNode*, int> tuple_get_indices_;
-  /*! \brief A map from a group to its dependent groups, used to detect cyclic dependencies. */
-  std::unordered_map<Group*, std::unordered_set<Group*>> group_deps_;
+  /*!
+   * \brief A map from a group to its dependent groups, used to detect cyclic dependencies.
+   * \note Use vector so we can be deterministic, there won't be a lot of dep groups so
+   *       linear search is OK.
+   */
+  std::unordered_map<Group*, std::vector<Group*>> group_deps_;
   /*! \brief Whether or not to lift bound constants to parameters of the grouped function. */
   bool lift_constants_{true};
 };
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index a7a6066c4b..33d57417cf 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -956,5 +956,164 @@ def test_layer_norm_silu():
     _check(Module, Expected)
 
 
+def test_multiple_paths():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"),
+            inp_1: R.Tensor((2, 1280), dtype="float32"),
+            w1: R.Tensor((320, 320, 3, 3), dtype="float32"),
+            b1: R.Tensor((320,), "float32"),
+            w2: R.Tensor((320, 1280), "float32"),
+            b2: R.Tensor((320,), "float32"),
+        ):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv27: R.Tensor((2, 320, 64, 64), dtype="float32") = R.nn.conv2d(inp_0, w1, strides=[1, 1], padding=[1, 1, 1, 1], dilation=[1, 1], groups=1, data_layout="NCHW", kernel_layout="OIHW", out_layout="NCHW", out_dtype="float32")
+                lv28: R.Tensor((1, 320, 1, 1), dtype="float32") = R.reshape(b1, R.shape([1, 320, 1, 1]))  ##
+                lv29: R.Tensor((2, 320, 64, 64), dtype="float32") = R.add(lv27, lv28)
+                lv31: R.Tensor((1280, 320), dtype="float32") = R.permute_dims(w2, axes=None)  ##
+                lv32: R.Tensor((2, 320), dtype="float32") = R.matmul(inp_1, lv31, out_dtype="float32")
+                lv33: R.Tensor((2, 320), dtype="float32") = R.add(lv32, b2)
+                lv35: R.Tensor((2, 320, 1, 1), dtype="float32") = R.reshape(lv33, R.shape([2, 320, 1, 1]))
+                lv36: R.Tensor((2, 320, 64, 64), dtype="float32") = R.add(lv29, lv35)
+                gv = lv36
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), 
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), 
T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), 
T.int64(320), T.int64(64), T.int64(64)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), 
T.int64(64), T.int64(64)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, 
v_ax1, v_ax2, v_ax3] + rxplaceholder_1[T.int64(0), v_ax1, T.int64(0), 
T.int64(0)]
+
+        @T.prim_func
+        def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(320),), "float32"), T_add: 
T.Buffer((T.int64(2), T.int64(320)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(320)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1], 
rxplaceholder_1[v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] + 
rxplaceholder_1[v_ax1]
+
+        @T.prim_func
+        def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), 
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), 
T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), 
T.int64(320), T.int64(64), T.int64(64)), "float32")):
+            T.func_attr({"op_pattern": 0, "tir.noalias": True})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), 
T.int64(64), T.int64(64)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3], 
rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)])
+                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, 
v_ax1, v_ax2, v_ax3] + rxplaceholder_1[v_ax0, v_ax1, T.int64(0), T.int64(0)]
+
+        @T.prim_func
+        def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), 
T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320), 
T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: 
T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": True})
+            pad_temp = T.alloc_buffer((T.int64(2), T.int64(320), T.int64(66), 
T.int64(66)))
+            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320), 
T.int64(66), T.int64(66)):
+                with T.block("pad_temp"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 
- T.int64(1)])
+                    T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
+                    pad_temp[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(65) and T.int64(1) <= v_i3 
and v_i3 < T.int64(65), rxplaceholder[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - 
T.int64(1)], T.float32(0))
+            for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(2), T.int64(320), 
T.int64(64), T.int64(64), T.int64(320), T.int64(3), T.int64(3)):
+                with T.block("conv2d_nchw"):
+                    v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = 
T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
+                    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], 
rxplaceholder_1[v_ff, v_rc, v_ry, v_rx])
+                    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
+                    with T.init():
+                        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
+                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, 
v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * 
rxplaceholder_1[v_ff, v_rc, v_ry, v_rx]
+
+        @T.prim_func
+        def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), 
"float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)), 
"float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")):
+            T.func_attr({"op_pattern": 4, "tir.noalias": True})
+            for i0, i1, k in T.grid(T.int64(2), T.int64(320), T.int64(1280)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(rxplaceholder[v_i0, v_k], rxplaceholder_1[v_k, 
v_i1])
+                    T.writes(matmul[v_i0, v_i1])
+                    with T.init():
+                        matmul[v_i0, v_i1] = T.float32(0)
+                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + 
rxplaceholder[v_i0, v_k] * rxplaceholder_1[v_k, v_i1]
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), 
T_reshape: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), 
"float32")):
+            T.func_attr({"op_pattern": 2, "tir.noalias": True})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(320), 
T.int64(1), T.int64(1)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % 
T.int64(320)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]
+
+        @T.prim_func
+        def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), 
"float32"), T_reshape: T.Buffer((T.int64(2), T.int64(320), T.int64(1), 
T.int64(1)), "float32")):
+            T.func_attr({"op_pattern": 2, "tir.noalias": True})
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), 
T.int64(1), T.int64(1)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // 
T.int64(320) + v_ax0) % T.int64(2), (v_ax1 + v_ax2 + v_ax3) % T.int64(320)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = 
rxplaceholder[((v_ax1 + v_ax2 + v_ax3) // T.int64(320) + v_ax0) % T.int64(2), 
(v_ax1 + v_ax2 + v_ax3) % T.int64(320)]
+
+        @T.prim_func
+        def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), 
"float32"), T_transpose: T.Buffer((T.int64(1280), T.int64(320)), "float32")):
+            T.func_attr({"op_pattern": 2, "tir.noalias": True})
+            for ax0, ax1 in T.grid(T.int64(1280), T.int64(320)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax1, v_ax0])
+                    T.writes(T_transpose[v_ax0, v_ax1])
+                    T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0]
+
+        @R.function
+        def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), 
dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: 
R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), 
dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv27 = R.call_tir(conv2d, (inp_0, w1), out_sinfo=R.Tensor((2, 
320, 64, 64), dtype="float32"))
+                lv29 = R.call_tir(add, (lv27, lv28), out_sinfo=R.Tensor((2, 
320, 64, 64), dtype="float32"))
+                gv = R.call_tir(add2, (lv29, lv35), out_sinfo=R.Tensor((2, 
320, 64, 64), dtype="float32"))
+                R.output(gv)
+            return gv
+
+        @R.function
+        def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), 
lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), 
dtype="float32")) -> R.Tensor((2, 320), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv32 = R.call_tir(matmul, (inp_1, lv31), 
out_sinfo=R.Tensor((2, 320), dtype="float32"))
+                gv = R.call_tir(add1, (lv32, b2), out_sinfo=R.Tensor((2, 320), 
dtype="float32"))
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), inp_1: 
R.Tensor((2, 1280), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), 
dtype="float32"), b1: R.Tensor((320,), dtype="float32"), w2: R.Tensor((320, 
1280), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 
320, 64, 64), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv28 = R.call_tir(reshape, (b1,), out_sinfo=R.Tensor((1, 320, 
1, 1), dtype="float32"))
+                lv31 = R.call_tir(transpose, (w2,), out_sinfo=R.Tensor((1280, 
320), dtype="float32"))
+                lv: R.Tensor((2, 320), dtype="float32") = 
fused_matmul_add1(inp_1, lv31, b2)
+                lv35 = R.call_tir(reshape1, (lv,), out_sinfo=R.Tensor((2, 320, 
1, 1), dtype="float32"))
+                lv1: R.Tensor((2, 320, 64, 64), dtype="float32") = 
fused_conv2d_add_add2(inp_0, w1, lv28, lv35)
+                gv: R.Tensor((2, 320, 64, 64), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    mod = relax.transform.LegalizeOps()(Module)
+    mod = relax.transform.AnnotateTIROpPattern()(mod)
+    mod = relax.transform.FuseOps()(mod)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to