This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 06a48996cc [Unity][Fix] Fix `rms_norm` tests (#16109)
06a48996cc is described below
commit 06a48996ccff000327c3f3c6271e27dbf08b3afb
Author: Yaxing Cai <[email protected]>
AuthorDate: Sat Nov 11 08:06:45 2023 -0800
[Unity][Fix] Fix `rms_norm` tests (#16109)
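This PR fixes the unit tests for `rms_norm` legalization by updating the expected TIR: the legalized `rms_norm` now casts its inputs to float32, performs the computation in float32, and casts the result back to the input dtype.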
---
.../python/relax/test_transform_legalize_ops_nn.py | 289 +++++++++------------
1 file changed, 126 insertions(+), 163 deletions(-)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 63e79c12cb..47f00dca87 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2756,27 +2756,32 @@ def test_rms_norm():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def rms_norm(
- A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
"float32"),
- B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
- T_rms_norm: T.Buffer(
- (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
- ),
- ):
+ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
+ T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
- with T.block("T_multiply"):
+ T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1,
v_ax2, v_ax3]
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2,
v_ax3]
- )
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
with T.block("T_multiply_red"):
v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
@@ -2784,40 +2789,24 @@ def test_rms_norm():
T.writes(T_multiply_red[v_ax0, v_ax1])
with T.init():
T_multiply_red[v_ax0, v_ax1] = T.float32(0)
- T_multiply_red[v_ax0, v_ax1] = (
- T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0,
v_ax1, v_k2, v_k3]
- )
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(
- A[v_ax0, v_ax1, v_ax2, v_ax3],
- B[v_ax2, v_ax3],
- T_multiply_red[v_ax0, v_ax1],
- )
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3]
- * B[v_ax2, v_ax3]
- * T.rsqrt(
- T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
- + T.float32(1e-5)
- )
- )
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0,
v_ax1, v_ax2, v_ax3]
@R.function
- def main(
- x: R.Tensor((2, 3, 4, 5), dtype="float32"),
- weight: R.Tensor((4, 5), dtype="float32"),
- ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight:
R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
cls = Expected
- gv = R.call_tir(
- cls.rms_norm,
- (x, weight),
- out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"),
- )
+ gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2,
3, 4, 5), dtype="float32"))
return gv
# fmt: on
mod = LegalizeOps()(RMSNorm)
@@ -2836,70 +2825,57 @@ def test_rms_norm_fp16():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def rms_norm(
- A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
"float16"),
- B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
- T_rms_norm: T.Buffer(
- (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
- ),
- ):
+ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float16"), B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
- T_multiply = T.alloc_buffer(
- (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"
- )
- T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)),
"float16")
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
- with T.block("T_multiply"):
+ T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32",
A[v_ax0, v_ax1, v_ax2, v_ax3])
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2,
v_ax3]
- )
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
with T.block("T_multiply_red"):
v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
T.writes(T_multiply_red[v_ax0, v_ax1])
with T.init():
- T_multiply_red[v_ax0, v_ax1] = T.float16(0)
- T_multiply_red[v_ax0, v_ax1] = (
- T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0,
v_ax1, v_k2, v_k3]
- )
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(
- A[v_ax0, v_ax1, v_ax2, v_ax3],
- B[v_ax2, v_ax3],
- T_multiply_red[v_ax0, v_ax1],
- )
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3]
- * B[v_ax2, v_ax3]
- * T.rsqrt(
- T_multiply_red[v_ax0, v_ax1] / (T.float16(4) *
T.float16(5))
- + T.float16(1e-5)
- )
- )
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float16",
T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
@R.function
- def main(
- x: R.Tensor((2, 3, 4, 5), dtype="float16"),
- weight: R.Tensor((4, 5), dtype="float16"),
- ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), weight:
R.Tensor((4, 5), dtype="float16")) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
cls = Expected
- gv = R.call_tir(
- cls.rms_norm,
- (x, weight),
- out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"),
- )
+ gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2,
3, 4, 5), dtype="float16"))
return gv
# fmt: on
mod = LegalizeOps()(RMSNorm)
@@ -2921,25 +2897,36 @@ def test_rms_norm_symbolic():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def rms_norm(
- var_A: T.handle, var_B: T.handle, var_T_rms_norm: T.handle
- ):
+ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
n, s, f = T.int64(), T.int64(), T.int64()
A = T.match_buffer(var_A, (n, s, f))
B = T.match_buffer(var_B, (s, f))
- T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+ T_cast = T.match_buffer(var_T_cast, (n, s, f))
# with T.block("root"):
+ T_cast_1 = T.alloc_buffer((n, s, f))
+ T_cast_2 = T.alloc_buffer((s, f))
T_multiply = T.alloc_buffer((n, s, f))
T_multiply_red = T.alloc_buffer((n,))
+ T_rms_norm = T.alloc_buffer((n, s, f))
for ax0, ax1, ax2 in T.grid(n, s, f):
- with T.block("T_multiply"):
+ with T.block("T_cast"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[v_ax0, v_ax1, v_ax2])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
+ T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
+ for ax0, ax1 in T.grid(s, f):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
- T_multiply[v_ax0, v_ax1, v_ax2] = (
- A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
- )
+ T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1,
v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
for ax0, k1, k2 in T.grid(n, s, f):
with T.block("T_multiply_red"):
v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
@@ -2947,42 +2934,27 @@ def test_rms_norm_symbolic():
T.writes(T_multiply_red[v_ax0])
with T.init():
T_multiply_red[v_ax0] = T.float32(0)
- T_multiply_red[v_ax0] = (
- T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
- )
+ T_multiply_red[v_ax0] = T_multiply_red[v_ax0] +
T_multiply[v_ax0, v_k1, v_k2]
for ax0, ax1, ax2 in T.grid(n, s, f):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(
- A[v_ax0, v_ax1, v_ax2],
- B[v_ax1, v_ax2],
- T_multiply_red[v_ax0],
- )
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1,
v_ax2], T_multiply_red[v_ax0])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
- T_rms_norm[v_ax0, v_ax1, v_ax2] = (
- A[v_ax0, v_ax1, v_ax2]
- * B[v_ax1, v_ax2]
- * T.rsqrt(
- T_multiply_red[v_ax0]
- / (T.Cast("float32", s) * T.Cast("float32", f))
- + T.float32(1e-5)
- )
- )
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1,
v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] /
(T.Cast("float32", s) * T.Cast("float32", f)) +
T.float32(1.0000000000000001e-05))
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+ T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1,
v_ax2]
@R.function
- def main(
- x: R.Tensor(("n", "s", "f"), dtype="float32"),
- weight: R.Tensor(("s", "f"), dtype="float32"),
- ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+ def main(x: R.Tensor(("n", "s", "f"), dtype="float32"), weight:
R.Tensor(("s", "f"), dtype="float32")) -> R.Tensor(("n", "s", "f"),
dtype="float32"):
n = T.int64()
s = T.int64()
f = T.int64()
cls = Expected
- gv = R.call_tir(
- cls.rms_norm,
- (x, weight),
- out_sinfo=R.Tensor((n, s, f), dtype="float32"),
- )
+ gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n,
s, f), dtype="float32"))
return gv
# fmt: on
mod = LegalizeOps()(RMSNorm)
@@ -3001,27 +2973,32 @@ def test_rms_norm_no_bias():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def rms_norm(
- A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
"float32"),
- B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
- T_rms_norm: T.Buffer(
- (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
- ),
- ):
+ def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
+ T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
- with T.block("T_multiply"):
+ T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1,
v_ax2, v_ax3]
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2,
v_ax3]
- )
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
with T.block("T_multiply_red"):
v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
@@ -3029,38 +3006,24 @@ def test_rms_norm_no_bias():
T.writes(T_multiply_red[v_ax0, v_ax1])
with T.init():
T_multiply_red[v_ax0, v_ax1] = T.float32(0)
- T_multiply_red[v_ax0, v_ax1] = (
- T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0,
v_ax1, v_k2, v_k3]
- )
- for ax0, ax1, ax2, ax3 in T.grid(
- T.int64(2), T.int64(3), T.int64(4), T.int64(5)
- ):
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(
- A[v_ax0, v_ax1, v_ax2, v_ax3],
- B[v_ax2, v_ax3],
- T_multiply_red[v_ax0, v_ax1],
- )
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
- A[v_ax0, v_ax1, v_ax2, v_ax3]
- * B[v_ax2, v_ax3]
- * T.rsqrt(
- T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
- + T.float32(1e-05)
- )
- )
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T_rms_norm[v_ax0,
v_ax1, v_ax2, v_ax3]
@R.function
- def main(
- x: R.Tensor((2, 3, 4, 5), dtype="float32"),
- weight: R.Tensor((4, 5), dtype="float32"),
- ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight:
R.Tensor((4, 5), dtype="float32")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
cls = Expected
- gv = R.call_tir(
- cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5),
dtype="float32")
- )
+ gv = R.call_tir(cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2,
3, 4, 5), dtype="float32"))
return gv
# fmt: on
mod = LegalizeOps()(RMSNorm)