This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 06a48996cc [Unity][Fix] Fix `rms_norm` tests (#16109)
06a48996cc is described below

commit 06a48996ccff000327c3f3c6271e27dbf08b3afb
Author: Yaxing Cai <[email protected]>
AuthorDate: Sat Nov 11 08:06:45 2023 -0800

    [Unity][Fix] Fix `rms_norm` tests (#16109)
    
    This PR fixes the unit tests of `rms_norm` legalization.
---
 .../python/relax/test_transform_legalize_ops_nn.py | 289 +++++++++------------
 1 file changed, 126 insertions(+), 163 deletions(-)

diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 63e79c12cb..47f00dca87 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2756,27 +2756,32 @@ def test_rms_norm():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def rms_norm(
-            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), 
"float32"),
-            B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
-            T_rms_norm: T.Buffer(
-                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
-            ),
-        ):
+        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), 
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
+            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
             T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
-                with T.block("T_multiply"):
+            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, 
v_ax2, v_ax3]
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, 
v_ax3]
-                    )
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
             for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
                 with T.block("T_multiply_red"):
                     v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
@@ -2784,40 +2789,24 @@ def test_rms_norm():
                     T.writes(T_multiply_red[v_ax0, v_ax1])
                     with T.init():
                         T_multiply_red[v_ax0, v_ax1] = T.float32(0)
-                    T_multiply_red[v_ax0, v_ax1] = (
-                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, 
v_ax1, v_k2, v_k3]
-                    )
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
+                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, 
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(
-                        A[v_ax0, v_ax1, v_ax2, v_ax3],
-                        B[v_ax2, v_ax3],
-                        T_multiply_red[v_ax0, v_ax1],
-                    )
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3]
-                        * B[v_ax2, v_ax3]
-                        * T.rsqrt(
-                            T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
-                            + T.float32(1e-5)
-                        )
-                    )
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, 
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0, 
v_ax1, v_ax2, v_ax3]
 
         @R.function
-        def main(
-            x: R.Tensor((2, 3, 4, 5), dtype="float32"),
-            weight: R.Tensor((4, 5), dtype="float32"),
-        ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: 
R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
             cls = Expected
-            gv = R.call_tir(
-                cls.rms_norm,
-                (x, weight),
-                out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"),
-            )
+            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 
3, 4, 5), dtype="float32"))
             return gv
     # fmt: on
     mod = LegalizeOps()(RMSNorm)
@@ -2836,70 +2825,57 @@ def test_rms_norm_fp16():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def rms_norm(
-            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), 
"float16"),
-            B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
-            T_rms_norm: T.Buffer(
-                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
-            ),
-        ):
+        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float16"), B: T.Buffer((T.int64(4), T.int64(5)), "float16"), 
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
-            T_multiply = T.alloc_buffer(
-                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
-            )
-            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)), 
"float16")
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
-                with T.block("T_multiply"):
+            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32", 
A[v_ax0, v_ax1, v_ax2, v_ax3])
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, 
v_ax3]
-                    )
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
             for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
                 with T.block("T_multiply_red"):
                     v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
                     T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
                     T.writes(T_multiply_red[v_ax0, v_ax1])
                     with T.init():
-                        T_multiply_red[v_ax0, v_ax1] = T.float16(0)
-                    T_multiply_red[v_ax0, v_ax1] = (
-                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, 
v_ax1, v_k2, v_k3]
-                    )
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, 
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(
-                        A[v_ax0, v_ax1, v_ax2, v_ax3],
-                        B[v_ax2, v_ax3],
-                        T_multiply_red[v_ax0, v_ax1],
-                    )
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3]
-                        * B[v_ax2, v_ax3]
-                        * T.rsqrt(
-                            T_multiply_red[v_ax0, v_ax1] / (T.float16(4) * 
T.float16(5))
-                            + T.float16(1e-5)
-                        )
-                    )
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, 
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16", 
T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
 
         @R.function
-        def main(
-            x: R.Tensor((2, 3, 4, 5), dtype="float16"),
-            weight: R.Tensor((4, 5), dtype="float16"),
-        ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), weight: 
R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
             cls = Expected
-            gv = R.call_tir(
-                cls.rms_norm,
-                (x, weight),
-                out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"),
-            )
+            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 
3, 4, 5), dtype="float16"))
             return gv
     # fmt: on
     mod = LegalizeOps()(RMSNorm)
@@ -2921,25 +2897,36 @@ def test_rms_norm_symbolic():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def rms_norm(
-            var_A: T.handle, var_B: T.handle, var_T_rms_norm: T.handle
-        ):
+        def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
             T.func_attr({"tir.noalias": T.bool(True)})
             n, s, f = T.int64(), T.int64(), T.int64()
             A = T.match_buffer(var_A, (n, s, f))
             B = T.match_buffer(var_B, (s, f))
-            T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+            T_cast = T.match_buffer(var_T_cast, (n, s, f))
             # with T.block("root"):
+            T_cast_1 = T.alloc_buffer((n, s, f))
+            T_cast_2 = T.alloc_buffer((s, f))
             T_multiply = T.alloc_buffer((n, s, f))
             T_multiply_red = T.alloc_buffer((n,))
+            T_rms_norm = T.alloc_buffer((n, s, f))
             for ax0, ax1, ax2 in T.grid(n, s, f):
-                with T.block("T_multiply"):
+                with T.block("T_cast"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                     T.reads(A[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
+                    T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
+            for ax0, ax1 in T.grid(s, f):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
                     T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
-                    T_multiply[v_ax0, v_ax1, v_ax2] = (
-                        A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
-                    )
+                    T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, 
v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
             for ax0, k1, k2 in T.grid(n, s, f):
                 with T.block("T_multiply_red"):
                     v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
@@ -2947,42 +2934,27 @@ def test_rms_norm_symbolic():
                     T.writes(T_multiply_red[v_ax0])
                     with T.init():
                         T_multiply_red[v_ax0] = T.float32(0)
-                    T_multiply_red[v_ax0] = (
-                        T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
-                    )
+                    T_multiply_red[v_ax0] = T_multiply_red[v_ax0] + 
T_multiply[v_ax0, v_k1, v_k2]
             for ax0, ax1, ax2 in T.grid(n, s, f):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                    T.reads(
-                        A[v_ax0, v_ax1, v_ax2],
-                        B[v_ax1, v_ax2],
-                        T_multiply_red[v_ax0],
-                    )
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1, 
v_ax2], T_multiply_red[v_ax0])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2] = (
-                        A[v_ax0, v_ax1, v_ax2]
-                        * B[v_ax1, v_ax2]
-                        * T.rsqrt(
-                            T_multiply_red[v_ax0]
-                            / (T.Cast("float32", s) * T.Cast("float32", f))
-                            + T.float32(1e-5)
-                        )
-                    )
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1, 
v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] / 
(T.Cast("float32", s) * T.Cast("float32", f)) + 
T.float32(1.0000000000000001e-05))
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+                    T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1, 
v_ax2]
 
         @R.function
-        def main(
-            x: R.Tensor(("n", "s", "f"), dtype="float32"),
-            weight: R.Tensor(("s", "f"), dtype="float32"),
-        ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+        def main(x: R.Tensor(("n", "s", "f"), dtype="float32"), weight: 
R.Tensor(("s", "f"), dtype="float32")) -> R.Tensor(("n", "s", "f"), 
dtype="float32"):
             n = T.int64()
             s = T.int64()
             f = T.int64()
             cls = Expected
-            gv = R.call_tir(
-                cls.rms_norm,
-                (x, weight),
-                out_sinfo=R.Tensor((n, s, f), dtype="float32"),
-            )
+            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, 
s, f), dtype="float32"))
             return gv
     # fmt: on
     mod = LegalizeOps()(RMSNorm)
@@ -3001,27 +2973,32 @@ def test_rms_norm_no_bias():
     @tvm.script.ir_module
     class Expected:
         @T.prim_func(private=True)
-        def rms_norm(
-            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), 
"float32"),
-            B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
-            T_rms_norm: T.Buffer(
-                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
-            ),
-        ):
+        def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), 
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
+            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
             T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
             T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
-                with T.block("T_multiply"):
+            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
                     T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1, 
v_ax2, v_ax3]
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, 
v_ax3]
-                    )
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
             for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), 
T.int64(5)):
                 with T.block("T_multiply_red"):
                     v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, 
k2, k3])
@@ -3029,38 +3006,24 @@ def test_rms_norm_no_bias():
                     T.writes(T_multiply_red[v_ax0, v_ax1])
                     with T.init():
                         T_multiply_red[v_ax0, v_ax1] = T.float32(0)
-                    T_multiply_red[v_ax0, v_ax1] = (
-                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, 
v_ax1, v_k2, v_k3]
-                    )
-            for ax0, ax1, ax2, ax3 in T.grid(
-                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
-            ):
+                    T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0, 
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
                 with T.block("T_rms_norm"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
-                    T.reads(
-                        A[v_ax0, v_ax1, v_ax2, v_ax3],
-                        B[v_ax2, v_ax3],
-                        T_multiply_red[v_ax0, v_ax1],
-                    )
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3], 
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
                     T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
-                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
-                        A[v_ax0, v_ax1, v_ax2, v_ax3]
-                        * B[v_ax2, v_ax3]
-                        * T.rsqrt(
-                            T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
-                            + T.float32(1e-05)
-                        )
-                    )
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0, 
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0, 
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), 
T.int64(4), T.int64(5)):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0, 
v_ax1, v_ax2, v_ax3]
 
         @R.function
-        def main(
-            x: R.Tensor((2, 3, 4, 5), dtype="float32"),
-            weight: R.Tensor((4, 5), dtype="float32"),
-        ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: 
R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
             cls = Expected
-            gv = R.call_tir(
-                cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), 
dtype="float32")
-            )
+            gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 
3, 4, 5), dtype="float32"))
             return gv
     # fmt: on
     mod = LegalizeOps()(RMSNorm)

Reply via email to