Lunderberg commented on code in PR #16487:
URL: https://github.com/apache/tvm/pull/16487#discussion_r1478397845


##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -1930,5 +1930,251 @@ def main(
     _check(Before, After)
 
 
+def test_inplace_simple():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add_inplace(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                # this overwrites x and is actually evil but we are doing it just to test the pass
+                lv = R.call_tir_inplace(
+                    cls.add_inplace,
+                    (x, p0),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32")
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(x[v_ax0, v_ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1])
+                    T.writes(x[v_i0, v_i1])
+                    x[v_i0, v_i1] = T.exp(x[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1])
+                    T.writes(x[v_ax0, v_ax1])
+                    x[v_ax0, v_ax1] = x[v_ax0, v_ax1]
+
+        # note that this will clobber x! Use with caution
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                    inplace_indices=[0],
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+
+def test_fuse_inplace_and_non_inplace():
+    @I.ir_module
+    class Module:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            B: T.Buffer((), "float32"),
+            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], B[()])
+                    T.writes(Out[v_ax0, v_ax1])
+                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]
+
+        @T.prim_func(private=True)
+        def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(A[v_i0, v_i1])
+                    T.writes(A[v_i0, v_i1])
+                    A[v_i0, v_i1] = T.exp(A[v_i0, v_i1])
+
+        @T.prim_func(private=True)
+        def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(A[v_ax0, v_ax1])
+                    A[v_ax0, v_ax1] = A[v_ax0, v_ax1]
+
+        @R.function(private=True)
+        def fused_add_exp_squeeze(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Module
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.add,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                lv1 = R.call_tir_inplace(
+                    cls.exp_inplace,
+                    (lv,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                gv = R.call_tir_inplace(
+                    cls.squeeze_inplace,
+                    (lv1,),
+                    inplace_indices=[0],
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Module
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0)
+                R.output(gv1)
+            return gv1
+
+    @I.ir_module
+    class Expected:
+        I.module_attrs({"foo": "bar"})
+
+        @T.prim_func(private=True)
+        def fused_add_exp_squeeze(
+            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+            p0: T.Buffer((), "float32"),
+            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[v_ax0, v_ax1], p0[()])
+                    T.writes(p_output0[v_ax0, v_ax1])
+                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
+            for i0, i1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(p_output0[v_i0, v_i1])
+                    T.writes(p_output0[v_i0, v_i1])
+                    p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(p_output0[v_ax0, v_ax1])
+                    T.writes(p_output0[v_ax0, v_ax1])
+                    p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
+        ) -> R.Tensor((10, 20), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
+                    cls.fused_add_exp_squeeze,
+                    (x, p0),
+                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
+                )
+                R.output(gv1)
+            return gv1
+
+    _check(Module, Expected)
+
+

Review Comment:
   Thank you, and yes, that is what I had in mind.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to