This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 92668c7d05 [Unity][Dlight] Fix reduction rule, aligning last block's
iters (#15383)
92668c7d05 is described below
commit 92668c7d0599b20d4da57d6cd4f1c561149e1472
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 24 21:40:56 2023 -0700
[Unity][Dlight] Fix reduction rule, aligning last block's iters (#15383)
[Unity][Dlight] Fix reduction rule, aligning last block's iters
Prior to this PR, the reduction dlight rule may sometimes
over-normalize the last block and remove some spatial loops, leaving
the last block with fewer block iters than the preceding reduction
blocks. In this case, the dlight rule fails to apply because of an
assertion failure, even though it is supposed to apply.
This PR thus adds unit block iters to the last block when it has fewer
iters than the reduction blocks preceding it, so that the numbers of
block iters align.
One softmax unit test is added to ensure that the dlight reduction rule
applies in this case.
---
python/tvm/dlight/gpu/general_reduction.py | 15 +++-
tests/python/dlight/test_gpu_general_reduction.py | 86 ++++++++++++++++++++++-
2 files changed, 98 insertions(+), 3 deletions(-)
diff --git a/python/tvm/dlight/gpu/general_reduction.py
b/python/tvm/dlight/gpu/general_reduction.py
index 097cd59d3a..31aa086cac 100644
--- a/python/tvm/dlight/gpu/general_reduction.py
+++ b/python/tvm/dlight/gpu/general_reduction.py
@@ -51,13 +51,26 @@ class GeneralReduction(ScheduleRule):
dom_kind = block_infos[0].dom_kind()
num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S"))
num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R"))
+
+ # Align the number of block iters of the last block.
+ num_last_block_iter = len(block_infos[-1].dom_kind())
+ if num_last_block_iter < len(dom_kind):
+ index_map = tir.IndexMap.from_func(
+ lambda *iters: (
+ [tir.const(0, iters[0].dtype)] * (len(dom_kind) -
num_last_block_iter)
+ + list(iters)
+ ),
+ ndim=num_last_block_iter,
+ )
+ sch.transform_block_layout(block_infos[-1].block_rv, index_map)
+
try:
# TODO: fix num_leading_s = 0 case
assert num_trailing_r > 0
for block in block_infos[1:-1]:
assert block.dom_kind() == dom_kind
assert block_infos[-1].is_injective()
- assert len(block_infos[-1].dom_kind()) == len(dom_kind)
+ assert len(block_infos[-1].dom_kind()) <= len(dom_kind)
except AssertionError:
return None
diff --git a/tests/python/dlight/test_gpu_general_reduction.py
b/tests/python/dlight/test_gpu_general_reduction.py
index dfabcd14d9..621449dcd8 100644
--- a/tests/python/dlight/test_gpu_general_reduction.py
+++ b/tests/python/dlight/test_gpu_general_reduction.py
@@ -31,7 +31,7 @@ def _check(mod_before: IRModule, mod_after: IRModule):
assert_structural_equal(mod, mod_after)
-def test_softmax():
+def test_softmax_1():
# fmt: off
@I.ir_module
class Before:
@@ -132,6 +132,87 @@ def test_softmax():
_check(Before, After)
+def test_softmax_2():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)),
"float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)),
"float32")):
+ # with T.block("root"):
+ T_softmax_maxelem = T.alloc_buffer((T.int64(1), T.int64(1)))
+ T_softmax_exp = T.alloc_buffer((T.int64(1), T.int64(1),
T.int64(32000)))
+ T_softmax_expsum = T.alloc_buffer((T.int64(1), T.int64(1)))
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("T_softmax_maxelem"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(A[v_i0, v_i1, v_k])
+ T.writes(T_softmax_maxelem[v_i0, v_i1])
+ with T.init():
+ T_softmax_maxelem[v_i0, v_i1] =
T.float32(-3.4028234663852886e+38)
+ T_softmax_maxelem[v_i0, v_i1] =
T.max(T_softmax_maxelem[v_i0, v_i1], A[v_i0, v_i1, v_k])
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("T_softmax_exp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(A[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
+ T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(A[v_i0, v_i1,
v_i2] - T_softmax_maxelem[v_i0, v_i1])
+ for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("T_softmax_expsum"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(T_softmax_exp[v_i0, v_i1, v_k])
+ T.writes(T_softmax_expsum[v_i0, v_i1])
+ with T.init():
+ T_softmax_expsum[v_i0, v_i1] = T.float32(0)
+ T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0,
v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
+ for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(32000)):
+ with T.block("T_softmax_norm"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2],
T_softmax_expsum[v_i0, v_i1])
+ T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
+ T.block_attr({"axis": 2})
+ T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0,
v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)),
"float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)),
"float32")):
+ T.func_attr({"tir.is_scheduled": 1})
+ T_softmax_maxelem_shared = T.alloc_buffer((T.int64(1),
T.int64(1)), scope="shared")
+ T_softmax_expsum_shared = T.alloc_buffer((T.int64(1), T.int64(1)),
scope="shared")
+ for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x",
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(125)):
+ for ax1_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ with T.block("T_softmax_maxelem"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.reduce(T.int64(32000), ax1_fused_0 *
T.int64(256) + ax1_fused_1)
+ T.reads(A[T.int64(0), T.int64(0), v1])
+ T.writes(T_softmax_maxelem_shared[T.int64(0),
T.int64(0)])
+ with T.init():
+ T_softmax_maxelem_shared[T.int64(0),
T.int64(0)] = T.float32(-3.4028234663852886e+38)
+ T_softmax_maxelem_shared[T.int64(0), T.int64(0)] =
T.max(T_softmax_maxelem_shared[T.int64(0), T.int64(0)], A[T.int64(0),
T.int64(0), v1])
+ for ax0, ax1_fused_0 in T.grid(T.int64(1), T.int64(125)):
+ for ax1_fused_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ with T.block("T_softmax_expsum"):
+ v0 = T.axis.spatial(T.int64(1), ax0)
+ v1 = T.axis.reduce(T.int64(32000), ax1_fused_0 *
T.int64(256) + ax1_fused_1)
+ T.reads(A[T.int64(0), T.int64(0), v1],
T_softmax_maxelem_shared[T.int64(0), T.int64(0)])
+ T.writes(T_softmax_expsum_shared[T.int64(0),
T.int64(0)])
+ with T.init():
+ T_softmax_expsum_shared[T.int64(0),
T.int64(0)] = T.float32(0)
+ T_softmax_expsum_shared[T.int64(0), T.int64(0)] =
T_softmax_expsum_shared[T.int64(0), T.int64(0)] + T.exp(A[T.int64(0),
T.int64(0), v1] - T_softmax_maxelem_shared[T.int64(0), T.int64(0)])
+ for ax1_0 in range(T.int64(125)):
+ for ax1_1 in T.thread_binding(T.int64(256),
thread="threadIdx.x"):
+ with T.block("T_softmax_norm"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(T.int64(32000), ax1_0 *
T.int64(256) + ax1_1)
+ T.reads(A[T.int64(0), T.int64(0), v1],
T_softmax_maxelem_shared[T.int64(0), T.int64(0)],
T_softmax_expsum_shared[T.int64(0), T.int64(0)])
+ T.writes(T_softmax_norm[T.int64(0), T.int64(0),
v1])
+ T.block_attr({"axis": 2})
+ T_softmax_norm[T.int64(0), T.int64(0), v1] =
T.exp(A[T.int64(0), T.int64(0), v1] - T_softmax_maxelem_shared[T.int64(0),
T.int64(0)]) / T_softmax_expsum_shared[T.int64(0), T.int64(0)]
+ # fmt: on
+ _check(Before, After)
+
+
def test_layer_norm():
# fmt: off
@I.ir_module
@@ -270,6 +351,7 @@ def test_rms_norm():
if __name__ == "__main__":
- test_softmax()
+ test_softmax_1()
+ test_softmax_2()
test_layer_norm()
test_rms_norm()