This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8e54a9e91d [Unity][DLight] Introduce Specific Rule for RMSNorm (#16338)
8e54a9e91d is described below
commit 8e54a9e91d64fd20788117f8556e2d61e7724ff3
Author: Linyu Wu <[email protected]>
AuthorDate: Mon Jan 8 14:19:45 2024 +0800
[Unity][DLight] Introduce Specific Rule for RMSNorm (#16338)
* [Unity][DLight] Introduce Specific Rule for RMSNorm
* fix: remove unused variables
* fix: rename invalid variables
* fix: deal with too general exception
* fix: update tests
* feat: make rule more general
---
include/tvm/topi/nn/rms_norm.h | 24 +-
python/tvm/dlight/gpu/__init__.py | 1 +
python/tvm/dlight/gpu/rmsnorm.py | 140 ++++++++++
tests/python/dlight/test_gpu_rmsnorm.py | 287 +++++++++++++++++++++
.../python/relax/test_transform_legalize_ops_nn.py | 100 ++++---
5 files changed, 513 insertions(+), 39 deletions(-)
diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
index ba2f7e49ac..7e95000f1e 100644
--- a/include/tvm/topi/nn/rms_norm.h
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -67,6 +67,25 @@ inline Tensor rms_norm(const Tensor& data, const Tensor&
weight, const Array<Int
for (int i : real_axis) {
reduce_extent *= data_fp32->shape[i];
}
+ auto rsqrt_func = [&](const Array<Var>& indices) {
+ Array<Var> non_reduce_indices;
+ for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+ if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end())
{
+ non_reduce_indices.push_back(indices[i]);
+ }
+ }
+ auto output =
+ tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent +
make_const(data_type, epsilon));
+ return output;
+ };
+ auto rsqrt_shape = Array<PrimExpr>();
+ for (int i = 0, n = static_cast<int>(data_fp32->shape.size()); i < n; ++i) {
+ if (std::find(real_axis.begin(), real_axis.end(), i) == real_axis.end()) {
+ rsqrt_shape.push_back(data_fp32->shape[i]);
+ }
+ }
+ auto rsqrt = tvm::te::compute(rsqrt_shape, rsqrt_func, "rsqrt", tag);
+
auto rms_norm_func = [&](const Array<Var>& indices) {
Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
@@ -76,12 +95,11 @@ inline Tensor rms_norm(const Tensor& data, const Tensor&
weight, const Array<Int
non_reduce_indices.push_back(indices[i]);
}
}
- auto output =
- data_fp32(indices) * weight_fp32(reduce_indices) *
- tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent +
make_const(data_type, epsilon));
+ auto output = rsqrt(non_reduce_indices) * data_fp32(indices) *
weight_fp32(reduce_indices);
return output;
};
auto rms_norm = tvm::te::compute(data_fp32->shape, rms_norm_func, name, tag);
+
return cast(rms_norm, data_type);
}
diff --git a/python/tvm/dlight/gpu/__init__.py
b/python/tvm/dlight/gpu/__init__.py
index f48bdb2c81..7db383a161 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -24,3 +24,4 @@ from .matmul import Matmul
from .reduction import Reduction
from .transpose import Transpose
from .general_reduction import GeneralReduction
+from .rmsnorm import RMSNorm
diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py
new file mode 100644
index 0000000000..f8b2bb4a17
--- /dev/null
+++ b/python/tvm/dlight/gpu/rmsnorm.py
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+"""A RMS norm schedule rule for GPU operators."""
+
+import tvm
+from tvm import tir
+from tvm.tir import Block, BufferStore
+from tvm.tir.expr import Cast, BufferLoad, Call
+from tvm.target import Target
+
+from ..base import ScheduleRule
+
+
+def identify_cast_or_load_block(block: Block) -> bool:
+    """Tell whether ``block`` is an element-wise copy, optionally through one ``Cast``.
+
+    The block qualifies when it has exactly one read region and one write
+    region, its body is a single ``BufferStore``, the stored value is a
+    ``BufferLoad`` (or a ``Cast`` directly wrapping a ``BufferLoad``), and the
+    load uses exactly the same index expressions as the store.
+    """
+    has_single_io = len(block.reads) == 1 and len(block.writes) == 1
+    if not has_single_io or not isinstance(block.body, BufferStore):
+        return False
+
+    stored = block.body.value
+    # Peel off at most one Cast; anything other than a plain load underneath disqualifies.
+    source = stored.value if isinstance(stored, Cast) else stored
+    if not isinstance(source, BufferLoad):
+        return False
+
+    target_indices = block.body.indices
+    if len(source.indices) != len(target_indices):
+        return False
+
+    # Indices must be the very same expression objects, position by position.
+    return all(src.same_as(dst) for src, dst in zip(source.indices, target_indices))
+
+
+def identify_rsqrt_block(block: Block) -> bool:
+    """Tell whether ``block`` stores the result of a top-level ``tir.rsqrt`` call.
+
+    The block must have exactly one read region and one write region, and its
+    body must be a ``BufferStore`` whose value is a ``Call`` to the
+    ``tir.rsqrt`` operator. The arguments of the call are not inspected.
+    """
+    if len(block.reads) != 1 or len(block.writes) != 1:
+        return False
+
+    body = block.body
+    if isinstance(body, BufferStore) and isinstance(body.value, Call):
+        return body.value.op == tvm.ir.op.Op.get("tir.rsqrt")
+    return False
+
+
+class RMSNorm(ScheduleRule):
+    """A rule for RMS norm."""
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> tir.Schedule:
+        """Schedule an RMS-norm shaped PrimFunc for GPU.
+
+        Parameters
+        ----------
+        func : tir.PrimFunc
+            The function to schedule.
+        target : Target
+            The compilation target; only ``target.kind.name`` is consulted,
+            to pick the number of threads per block.
+        _ : bool
+            Unused.
+
+        Returns
+        -------
+        tir.Schedule or None
+            The schedule, or ``None`` when ``func`` does not have the expected
+            structure (no rsqrt block, or not exactly the six blocks
+            read / square / reduce-sum / rsqrt / norm / write once the
+            cast-or-load blocks have been inlined).
+        """
+        if target.kind.name == "cuda":
+            num_tx = 512
+        else:
+            num_tx = 64
+
+        sch = tir.Schedule(func)
+        root = sch.get_block(name="root", func_name="main")
+
+        blocks = sch.get_child_blocks(root)
+
+        # Cheap early exit: without an rsqrt block this cannot be an RMS norm.
+        if not any(identify_rsqrt_block(sch.get(block)) for block in blocks):
+            return None
+
+        # Stage the input of the first block and the output of the last block
+        # through "local" buffers. The returned handles are not needed here:
+        # the block list is re-queried after inlining below.
+        sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local")
+        sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local")
+
+        for block in blocks:
+            if identify_cast_or_load_block(sch.get(block)):
+                sch.compute_inline(block)
+
+        blocks = sch.get_child_blocks(root)
+
+        # The rest of the rule assumes exactly six blocks in this order. Bail
+        # out instead of letting the tuple unpacking raise ValueError on a
+        # function that merely contains an rsqrt.
+        if len(blocks) != 6:
+            return None
+        read, sqr, redsum, rsqrt, norm, write = blocks
+
+        if not identify_rsqrt_block(sch.get(rsqrt)):
+            return None
+
+        # Collapse all but the innermost loop of every block into one loop.
+        for blk in [read, sqr, redsum, rsqrt, norm, write]:
+            loops = sch.get_loops(blk)
+            sch.fuse(*loops[:-1])
+
+        # First phase: load, square and reduce. The fused outer loop maps to
+        # blockIdx.x; the inner loop is split into (threads, remainder, 8) and
+        # the innermost factor-8 loop of the read block is vectorized.
+        block_loop, loops = sch.get_loops(block=read)
+        thread_loop, _, _ = sch.split(
+            loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True
+        )
+        sch.bind(block_loop, thread_axis="blockIdx.x")
+        sch.bind(thread_loop, thread_axis="threadIdx.x")
+        sch.vectorize(sch.get_loops(block=read)[-1])
+        sch.reverse_compute_at(block=sqr, loop=thread_loop)
+        sch.reverse_compute_at(block=redsum, loop=thread_loop)
+
+        # Second phase: rsqrt and the normalisation are placed at the end of
+        # the blockIdx.x loop, and the norm loop gets the same thread split.
+        sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1)
+        sch.reverse_compute_at(block=norm, loop=block_loop, index=-1)
+        block_loop, loops = sch.get_loops(block=norm)
+        thread_loop, _, _ = sch.split(
+            loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True
+        )
+        sch.bind(thread_loop, thread_axis="threadIdx.x")
+
+        sch.reverse_compute_at(block=write, loop=thread_loop, index=-1)
+        sch.vectorize(sch.get_loops(block=write)[-1])
+
+        # Intermediate buffers are "local"; the rsqrt result is "shared",
+        # presumably because every thread of the block reads it in the norm stage.
+        sch.set_scope(block=sqr, buffer_index=0, storage_scope="local")
+        sch.set_scope(block=redsum, buffer_index=0, storage_scope="local")
+        sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared")
+        sch.set_scope(block=norm, buffer_index=0, storage_scope="local")
+
+        return sch
diff --git a/tests/python/dlight/test_gpu_rmsnorm.py
b/tests/python/dlight/test_gpu_rmsnorm.py
new file mode 100644
index 0000000000..301dac5c66
--- /dev/null
+++ b/tests/python/dlight/test_gpu_rmsnorm.py
@@ -0,0 +1,287 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import tvm.testing
+
+from tvm.ir import IRModule, assert_structural_equal
+from tvm import dlight as dl
+from tvm.script import ir as I
+from tvm.target import Target
+from tvm.script import tir as T
+
+
+def _check(mod_before: IRModule, mod_after: IRModule):
+    """Apply only the RMSNorm rule to ``mod_before`` and compare against ``mod_after``.
+
+    The rule runs under the "nvidia/geforce-rtx-3090-ti" target, and the result
+    must be structurally equal to the expected module.
+    """
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        apply_rule = dl.ApplyDefaultSchedule(
+            dl.gpu.RMSNorm(),
+        )
+        scheduled = apply_rule(mod_before)  # pylint: disable=not-callable
+    assert_structural_equal(scheduled, mod_after)
+
+
+def test_rms_norm_with_casting():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"),
var_T_cast: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int32()
+ data = T.match_buffer(var_data, (1, n, 4096), "float16")
+ T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16")
+ # with T.block("root"):
+ T_cast_1 = T.alloc_buffer((1, n, 4096))
+ T_multiply = T.alloc_buffer((1, n, 4096))
+ T_multiply_red = T.alloc_buffer((1, n))
+ rsqrt = T.alloc_buffer((1, n))
+ T_cast_2 = T.alloc_buffer((4096,))
+ T_rms_norm = T.alloc_buffer((1, n, 4096))
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_cast"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(data[v_ax0, v_ax1, v_ax2])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
+ T_cast_1[v_ax0, v_ax1, v_ax2] = T.Cast("float32",
data[v_ax0, v_ax1, v_ax2])
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+ T_multiply[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1,
v_ax2] * T_cast_1[v_ax0, v_ax1, v_ax2]
+ for ax0, ax1, k2 in T.grid(1, n, 4096):
+ with T.block("T_multiply_red"):
+ v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
+ for ax0, ax1 in T.grid(1, n):
+ with T.block("rsqrt"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(rsqrt[v_ax0, v_ax1])
+ rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1]
* T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+ for ax0 in range(4096):
+ with T.block("T_cast_1"):
+ v_ax0 = T.axis.spatial(4096, ax0)
+ T.reads(weight[v_ax0])
+ T.writes(T_cast_2[v_ax0])
+ T_cast_2[v_ax0] = T.Cast("float32", weight[v_ax0])
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1,
v_ax2], T_cast_2[v_ax2])
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] *
T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax2]
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+ T_cast[v_ax0, v_ax1, v_ax2] = T.Cast("float16",
T_rms_norm[v_ax0, v_ax1, v_ax2])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"),
var_T_cast: T.handle):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int32()
+ data = T.match_buffer(var_data, (1, n, 4096), "float16")
+ T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16")
+ # with T.block("root"):
+ T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local")
+ T_multiply_red_local = T.alloc_buffer((1, n), scope="local")
+ rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
+ T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
+ data_local = T.alloc_buffer((1, n, 4096), "float16", scope="local")
+ for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+ for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+ for ax2_1 in range(1):
+ for ax2_2 in T.vectorized(8):
+ with T.block("data_local"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused)
+ v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 *
8 + ax2_2)
+ T.reads(data[v0, v1, v2])
+ T.writes(data_local[v0, v1, v2])
+ data_local[v0, v1, v2] = data[v0, v1, v2]
+ for ax0 in range(8):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+ T.reads(data_local[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+ T_multiply_local[v_ax0, v_ax1, v_ax2] =
T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2]) * T.Cast("float32",
data_local[v_ax0, v_ax1, v_ax2])
+ for ax0 in range(8):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+ T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+ T.writes(T_multiply_red_local[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red_local[v_ax0, v_ax1] =
T.float32(0)
+ T_multiply_red_local[v_ax0, v_ax1] =
T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
+ with T.block("rsqrt"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ T.reads(T_multiply_red_local[v_ax0, v_ax1])
+ T.writes(rsqrt_shared[v_ax0, v_ax1])
+ rsqrt_shared[v_ax0, v_ax1] =
T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) +
T.float32(9.9999999999999995e-07))
+ for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
+ for ax0_1, ax0_2 in T.grid(1, 8):
+ with T.block("T_rms_norm"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8
+ ax0_2)
+ T.reads(rsqrt_shared[v_ax0, v_ax1],
data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+ T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+ T_rms_norm_local[v_ax0, v_ax1, v_ax2] =
rsqrt_shared[v_ax0, v_ax1] * T.Cast("float32", data_local[v_ax0, v_ax1, v_ax2])
* T.Cast("float32", weight[v_ax2])
+ for ax0 in T.vectorized(8):
+ with T.block("T_cast_local"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused)
+ v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
+ T.reads(T_rms_norm_local[v0, v1, v2])
+ T.writes(T_cast[v0, v1, v2])
+ T_cast[v0, v1, v2] = T.Cast("float16",
T_rms_norm_local[v0, v1, v2])
+ # fmt: on
+ _check(Before, After)
+
+
+def test_rms_norm_without_casting():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"),
var_T_cast: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n = T.int32()
+ data = T.match_buffer(var_data, (1, n, 4096))
+ T_cast = T.match_buffer(var_T_cast, (1, n, 4096))
+ # with T.block("root"):
+ T_multiply = T.alloc_buffer((1, n, 4096))
+ T_multiply_red = T.alloc_buffer((1, n))
+ rsqrt = T.alloc_buffer((1, n))
+ T_rms_norm = T.alloc_buffer((1, n, 4096))
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(data[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+ T_multiply[v_ax0, v_ax1, v_ax2] = data[v_ax0, v_ax1,
v_ax2] * data[v_ax0, v_ax1, v_ax2]
+ for ax0, ax1, k2 in T.grid(1, n, 4096):
+ with T.block("T_multiply_red"):
+ v_ax0, v_ax1, v_k2 = T.axis.remap("SSR", [ax0, ax1, k2])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2]
+ for ax0, ax1 in T.grid(1, n):
+ with T.block("rsqrt"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(rsqrt[v_ax0, v_ax1])
+ rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1]
* T.float32(0.000244140625) + T.float32(9.9999999999999995e-07))
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rsqrt[v_ax0, v_ax1], data[v_ax0, v_ax1, v_ax2],
weight[v_ax2])
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0, v_ax1] *
data[v_ax0, v_ax1, v_ax2] * weight[v_ax2]
+ for ax0, ax1, ax2 in T.grid(1, n, 4096):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2])
+ T_cast[v_ax0, v_ax1, v_ax2] = T_rms_norm[v_ax0, v_ax1,
v_ax2]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"),
var_T_cast: T.handle):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ n = T.int32()
+ data = T.match_buffer(var_data, (1, n, 4096))
+ T_cast = T.match_buffer(var_T_cast, (1, n, 4096))
+ # with T.block("root"):
+ T_multiply_local = T.alloc_buffer((1, n, 4096), scope="local")
+ T_multiply_red_local = T.alloc_buffer((1, n), scope="local")
+ rsqrt_shared = T.alloc_buffer((1, n), scope="shared")
+ T_rms_norm_local = T.alloc_buffer((1, n, 4096), scope="local")
+ data_local = T.alloc_buffer((1, n, 4096), scope="local")
+ for ax0_ax1_fused in T.thread_binding(n, thread="blockIdx.x"):
+ for ax2_0 in T.thread_binding(512, thread="threadIdx.x"):
+ for ax2_1 in range(1):
+ for ax2_2 in T.vectorized(8):
+ with T.block("data_local"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused)
+ v2 = T.axis.spatial(4096, ax2_0 * 8 + ax2_1 *
8 + ax2_2)
+ T.reads(data[v0, v1, v2])
+ T.writes(data_local[v0, v1, v2])
+ data_local[v0, v1, v2] = data[v0, v1, v2]
+ for ax0 in range(8):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_ax2 = T.axis.spatial(4096, ax2_0 * 8 + ax0)
+ T.reads(data_local[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply_local[v_ax0, v_ax1, v_ax2])
+ T_multiply_local[v_ax0, v_ax1, v_ax2] =
data_local[v_ax0, v_ax1, v_ax2] * data_local[v_ax0, v_ax1, v_ax2]
+ for ax0 in range(8):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_k2 = T.axis.reduce(4096, ax2_0 * 8 + ax0)
+ T.reads(T_multiply_local[v_ax0, v_ax1, v_k2])
+ T.writes(T_multiply_red_local[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red_local[v_ax0, v_ax1] =
T.float32(0)
+ T_multiply_red_local[v_ax0, v_ax1] =
T_multiply_red_local[v_ax0, v_ax1] + T_multiply_local[v_ax0, v_ax1, v_k2]
+ with T.block("rsqrt"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ T.reads(T_multiply_red_local[v_ax0, v_ax1])
+ T.writes(rsqrt_shared[v_ax0, v_ax1])
+ rsqrt_shared[v_ax0, v_ax1] =
T.rsqrt(T_multiply_red_local[v_ax0, v_ax1] * T.float32(0.000244140625) +
T.float32(9.9999999999999995e-07))
+ for ax0_0 in T.thread_binding(512, thread="threadIdx.x"):
+ for ax0_1, ax0_2 in T.grid(1, 8):
+ with T.block("T_rms_norm"):
+ v_ax0 = T.axis.spatial(1, 0)
+ v_ax1 = T.axis.spatial(n, ax0_ax1_fused)
+ v_ax2 = T.axis.spatial(4096, ax0_0 * 8 + ax0_1 * 8
+ ax0_2)
+ T.reads(rsqrt_shared[v_ax0, v_ax1],
data_local[v_ax0, v_ax1, v_ax2], weight[v_ax2])
+ T.writes(T_rms_norm_local[v_ax0, v_ax1, v_ax2])
+ T_rms_norm_local[v_ax0, v_ax1, v_ax2] =
rsqrt_shared[v_ax0, v_ax1] * data_local[v_ax0, v_ax1, v_ax2] * weight[v_ax2]
+ for ax0 in T.vectorized(8):
+ with T.block("T_cast_local"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(n, ax0_ax1_fused)
+ v2 = T.axis.spatial(4096, ax0_0 * 8 + ax0)
+ T.reads(T_rms_norm_local[v0, v1, v2])
+ T.writes(T_cast[v0, v1, v2])
+ T_cast[v0, v1, v2] = T_rms_norm_local[v0, v1, v2]
+ # fmt: on
+ _check(Before, After)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 74da77f7d8..07fbc3419b 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2773,9 +2773,10 @@ def test_rms_norm():
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
- T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast"):
@@ -2783,12 +2784,6 @@ def test_rms_norm():
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1,
v_ax2, v_ax3]
- for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
- with T.block("T_cast_1"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(B[v_ax0, v_ax1])
- T.writes(T_cast_2[v_ax0, v_ax1])
- T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_multiply"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
@@ -2803,12 +2798,24 @@ def test_rms_norm():
with T.init():
T_multiply_red[v_ax0, v_ax1] = T.float32(0)
T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("rsqrt"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(rsqrt[v_ax0, v_ax1])
+ rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1]
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+ T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_cast_2[v_ax2, v_ax3])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0,
v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast_2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
@@ -2842,9 +2849,10 @@ def test_rms_norm_fp16():
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
- T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast"):
@@ -2852,12 +2860,6 @@ def test_rms_norm_fp16():
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast("float32",
A[v_ax0, v_ax1, v_ax2, v_ax3])
- for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
- with T.block("T_cast_1"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(B[v_ax0, v_ax1])
- T.writes(T_cast_2[v_ax0, v_ax1])
- T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_multiply"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
@@ -2872,12 +2874,24 @@ def test_rms_norm_fp16():
with T.init():
T_multiply_red[v_ax0, v_ax1] = T.float32(0)
T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("rsqrt"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(rsqrt[v_ax0, v_ax1])
+ rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1]
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+ T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_cast_2[v_ax2, v_ax3])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0,
v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast_2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
@@ -2918,9 +2932,10 @@ def test_rms_norm_symbolic():
T_cast = T.match_buffer(var_T_cast, (n, s, f))
# with T.block("root"):
T_cast_1 = T.alloc_buffer((n, s, f))
- T_cast_2 = T.alloc_buffer((s, f))
T_multiply = T.alloc_buffer((n, s, f))
T_multiply_red = T.alloc_buffer((n,))
+ rsqrt = T.alloc_buffer((n,))
+ T_cast_2 = T.alloc_buffer((s, f))
T_rms_norm = T.alloc_buffer((n, s, f))
for ax0, ax1, ax2 in T.grid(n, s, f):
with T.block("T_cast"):
@@ -2928,12 +2943,6 @@ def test_rms_norm_symbolic():
T.reads(A[v_ax0, v_ax1, v_ax2])
T.writes(T_cast_1[v_ax0, v_ax1, v_ax2])
T_cast_1[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2]
- for ax0, ax1 in T.grid(s, f):
- with T.block("T_cast_1"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(B[v_ax0, v_ax1])
- T.writes(T_cast_2[v_ax0, v_ax1])
- T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2 in T.grid(n, s, f):
with T.block("T_multiply"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -2948,12 +2957,24 @@ def test_rms_norm_symbolic():
with T.init():
T_multiply_red[v_ax0] = T.float32(0)
T_multiply_red[v_ax0] = T_multiply_red[v_ax0] +
T_multiply[v_ax0, v_k1, v_k2]
+ for ax0 in range(n):
+ with T.block("rsqrt"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ T.reads(T_multiply_red[v_ax0])
+ T.writes(rsqrt[v_ax0])
+ rsqrt[v_ax0] = T.rsqrt(T_multiply_red[v_ax0] /
(T.Cast("float32", s) * T.Cast("float32", f)) +
T.float32(1.0000000000000001e-05))
+ for ax0, ax1 in T.grid(s, f):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2 in T.grid(n, s, f):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
- T.reads(T_cast_1[v_ax0, v_ax1, v_ax2], T_cast_2[v_ax1,
v_ax2], T_multiply_red[v_ax0])
+ T.reads(rsqrt[v_ax0], T_cast_1[v_ax0, v_ax1, v_ax2],
T_cast_2[v_ax1, v_ax2])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
- T_rms_norm[v_ax0, v_ax1, v_ax2] = T_cast_1[v_ax0, v_ax1,
v_ax2] * T_cast_2[v_ax1, v_ax2] * T.rsqrt(T_multiply_red[v_ax0] /
(T.Cast("float32", s) * T.Cast("float32", f)) +
T.float32(1.0000000000000001e-05))
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = rsqrt[v_ax0] *
T_cast_1[v_ax0, v_ax1, v_ax2] * T_cast_2[v_ax1, v_ax2]
for ax0, ax1, ax2 in T.grid(n, s, f):
with T.block("T_cast_2"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
@@ -2990,9 +3011,10 @@ def test_rms_norm_no_bias():
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
- T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ rsqrt = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast"):
@@ -3000,12 +3022,6 @@ def test_rms_norm_no_bias():
T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = A[v_ax0, v_ax1,
v_ax2, v_ax3]
- for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
- with T.block("T_cast_1"):
- v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(B[v_ax0, v_ax1])
- T.writes(T_cast_2[v_ax0, v_ax1])
- T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_multiply"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
@@ -3020,12 +3036,24 @@ def test_rms_norm_no_bias():
with T.init():
T_multiply_red[v_ax0, v_ax1] = T.float32(0)
T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("rsqrt"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(rsqrt[v_ax0, v_ax1])
+ rsqrt[v_ax0, v_ax1] = T.rsqrt(T_multiply_red[v_ax0, v_ax1]
* T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = B[v_ax0, v_ax1]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_rms_norm"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_cast_2[v_ax2, v_ax3], T_multiply_red[v_ax0, v_ax1])
+ T.reads(rsqrt[v_ax0, v_ax1], T_cast_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_cast_2[v_ax2, v_ax3])
T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
- T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = T_cast_1[v_ax0,
v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3] * T.rsqrt(T_multiply_red[v_ax0,
v_ax1] * T.float32(0.050000000000000003) + T.float32(1.0000000000000001e-05))
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = rsqrt[v_ax0,
v_ax1] * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_cast_2[v_ax2, v_ax3]
for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_cast_2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])