This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b00ed76318 [Unity][Dlight] Minor performance improvement for gemm and gemv (#15278)
b00ed76318 is described below
commit b00ed763183e7c4ad2f9d97ab1d2b9291b838bb7
Author: Siyuan Feng <[email protected]>
AuthorDate: Mon Jul 10 22:00:05 2023 +0800
[Unity][Dlight] Minor performance improvement for gemm and gemv (#15278)
This PR improves the performance of gemm and gemv in dlight:
- gemm: use vectorized load/store and change the thread bindings.
- gemv: apply unrolling.
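For context, the gemv change is just the two `sch.annotate` calls visible in the decode_gemv.py hunk below. A minimal sketch of the same pattern on a toy gemv, using the public tvm.tir.Schedule API; the workload, split factor, and bindings here are illustrative stand-ins, not taken from the dlight rule:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def gemv(A: T.Buffer((4096, 4096), "float16"),
             x: T.Buffer((4096,), "float16"),
             y: T.Buffer((4096,), "float16")):
        for i, k in T.grid(4096, 4096):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    y[vi] = T.float16(0)
                y[vi] = y[vi] + A[vi, vk] * x[vk]

    sch = tvm.tir.Schedule(gemv)
    i, k = sch.get_loops(sch.get_block("gemv"))
    sch.bind(i, "blockIdx.x")
    ko, ki = sch.split(k, factors=[None, 256])  # illustrative factor
    sch.reorder(ki, ko)  # thread loop outside, serial reduction inside
    sch.bind(ki, "threadIdx.x")
    # Mirrors the patch: ask the TIR unroller to unroll the serial loop nested
    # under the thread-bound loop (up to 256 steps), with explicit unrolling.
    sch.annotate(ki, ann_key="pragma_auto_unroll_max_step", ann_val=256)
    sch.annotate(ki, ann_key="pragma_unroll_explicit", ann_val=1)

These annotations surface in TIR as the `annotations={...}` entries on the thread-bound loops, which is exactly what the updated test expectations below check for.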
---
python/tvm/dlight/gpu/decode_gemv.py | 2 +
python/tvm/dlight/gpu/matmul.py | 11 ++-
tests/python/dlight/test_gpu_decode_gemv.py | 10 +--
tests/python/dlight/test_gpu_matmul.py | 112 ++++++++++++++--------------
4 files changed, 70 insertions(+), 65 deletions(-)
diff --git a/python/tvm/dlight/gpu/decode_gemv.py b/python/tvm/dlight/gpu/decode_gemv.py
index d0d37e8476..6c7e31181b 100644
--- a/python/tvm/dlight/gpu/decode_gemv.py
+++ b/python/tvm/dlight/gpu/decode_gemv.py
@@ -206,6 +206,8 @@ class DecodeGEMV(ScheduleRule):
sch.reorder(bx, tx, r)
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
+ sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256)
+ sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
sch.set_scope(rf, 0, "local")
sch.decompose_reduction(rf, r)
# Schedule the write back block
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 86d685e53c..be5e4b02d7 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -297,7 +297,7 @@ class Matmul(ScheduleRule):
block_size_y = 16
vthread_x = 1
vthread_y = 1
- micro_size_x = 2
+ micro_size_x = 4
micro_size_y = 4
micro_size_k = 16
vector_size = 2
@@ -340,20 +340,23 @@ class Matmul(ScheduleRule):
l2g = sch.cache_write(main_block, 0, "local")
sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)
+ if micro_size_y % vector_size == 0:
+ _, v = sch.split(sch.get_loops(l2g)[-1], [None, vector_size])
+ sch.vectorize(v)
def _cooperative_fetch(index, vec_len):
block = sch.cache_read(main_block, index, "shared")
num_loops = len(sch.get_loops(block))
sch.compute_at(block, ko, preserve_unit_loops=True)
loops = sch.get_loops(block)[-num_loops:]
- _, ty, tx, vec = sch.split(
+ ty, tx, _, vec = sch.split(
sch.fuse(*loops),
- factors=[None, block_size_y, block_size_x, vec_len],
+ factors=[block_size_y, block_size_x, None, vec_len],
)
sch.vectorize(vec)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
- sch.storage_align(block, 0, axis=1, factor=32, offset=vec_len)
+ sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len)
return block
a_g2s = _cooperative_fetch(0, vec_len=vector_size)
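Two gemm-side tweaks recur throughout the test expectations below: the cooperative-fetch split now puts the threadIdx.y/threadIdx.x loops outermost, and the local-to-global write-back is split by `vector_size` and vectorized (guarded by `micro_size_y % vector_size == 0`). A self-contained sketch of the split-and-vectorize idiom on a toy copy kernel; the kernel, extents, and names are illustrative, not from the patch:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def copy(A: T.Buffer((4096,), "float32"), B: T.Buffer((4096,), "float32")):
        for i in range(4096):
            with T.block("copy"):
                vi = T.axis.spatial(4096, i)
                B[vi] = A[vi]

    sch = tvm.tir.Schedule(copy)
    (i,) = sch.get_loops(sch.get_block("copy"))
    vector_size = 2  # the same constant the matmul rule uses
    _, v = sch.split(i, factors=[None, vector_size])
    sch.vectorize(v)  # inner loop becomes a vectorized loop in TIR

After vectorization the innermost loads/stores are emitted as vector accesses, which is the `T.vectorized(T.int64(2))` loop visible in the expected outputs below.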
diff --git a/tests/python/dlight/test_gpu_decode_gemv.py b/tests/python/dlight/test_gpu_decode_gemv.py
index bd84aeb096..7b19e6b7f8 100644
--- a/tests/python/dlight/test_gpu_decode_gemv.py
+++ b/tests/python/dlight/test_gpu_decode_gemv.py
@@ -56,7 +56,7 @@ def test_decode_gemv_1():
# with T.block("root"):
C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", scope="local")
for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
- for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("matmul_rf_init"):
vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
@@ -179,7 +179,7 @@ def test_decode_gemv_3():
# with T.block("root"):
C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", scope="local")
for i2_0_i0_i1_fused in T.thread_binding(512, thread="blockIdx.x"):
- for k_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for k_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for i2_1_init in range(8):
with T.block("matmul_rf_init"):
vk_fused_1 = T.axis.spatial(256, k_fused_1)
@@ -315,7 +315,7 @@ def test_decode_gemv_sigmoid():
C_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local")
C_rf_local = T.alloc_buffer((256, 1, 1, 4096), "float16", scope="local")
for i2_i0_i1_fused in T.thread_binding(4096, thread="blockIdx.x"):
- for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for k_0_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("matmul_rf_init"):
vk_0_fused_1 = T.axis.spatial(256, k_0_fused_1)
v_i2 = T.axis.spatial(4096, i2_i0_i1_fused)
@@ -387,7 +387,7 @@ def test_decode_gemv_1_fp32():
C_fp32_local = T.alloc_buffer((1, 1, 4096), scope="local")
C_fp32_rf_local = T.alloc_buffer((256, 1, 1, 4096), scope="local")
for ax0_fused in T.thread_binding(4096, thread="blockIdx.x"):
- for ax1_0_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for ax1_0_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("matmul_rf_init"):
vax1_0_fused_1, v0 = T.axis.remap("SS", [ax1_0_fused_1, ax0_fused])
T.reads()
@@ -450,7 +450,7 @@ def test_reduction_no_spatial():
Ared_temp_shared = T.alloc_buffer((1, 1), scope="shared")
Ared_temp_rf_local = T.alloc_buffer((256, 1, 1), scope="local")
for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): # pylint: disable=unused-variable
- for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+ for ax1_fused_1 in T.thread_binding(256, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
with T.block("Ared_temp_rf_init"):
vax1_fused_1 = T.axis.spatial(256, ax1_fused_1)
v0 = T.axis.spatial(T.int64(1), T.int64(0))
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 38ecbfee94..f3d9a7089d 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,8 +19,6 @@ import pytest
import tvm.testing
from tvm import dlight as dl
-from tvm.ir import assert_structural_equal
-from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
@@ -56,67 +54,68 @@ class TestMatmul(BaseBeforeAfter):
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
# with T.block("root"):
- matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
- inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="shared")
+ matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+ inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
- for ax1_0 in T.thread_binding((m + T.int64(15)) // T.int64(16), thread="blockIdx.x"):
+ for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
with T.block("matmul_init"):
v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
T.reads()
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.block("inp0_reindex_pad_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(inp0[v0, v1, v2])
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
- for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.block("inp1_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(inp1[v2, v1])
T.writes(inp1_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
- for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
with T.block("matmul_update"):
v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0), v2, v3])
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
matmul_reindex_pad_local[T.int64(0), v1, v2] = matmul_reindex_pad_local[T.int64(0), v1, v2] + inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0), v2, v3]
- for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
- with T.block("matmul_reindex_pad_local"):
- v0 = T.axis.spatial(T.int64(1), ax0_1)
- v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
- T.reads(matmul_reindex_pad_local[v0, v1, v2])
- T.writes(matmul[T.int64(0), v1, v2])
- if v1 < m:
- matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
+ for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with T.block("matmul_reindex_pad_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0_1)
+ v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
+ T.reads(matmul_reindex_pad_local[v0, v1, v2])
+ T.writes(matmul[T.int64(0), v1, v2])
+ if v1 < m:
+ matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
@@ -147,70 +146,71 @@ class TestFusedMatmul(BaseBeforeAfter):
T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
T.writes(Out[v_ax0, v_ax1, v_ax2])
Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
-
@T.prim_func
def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")):
T.func_attr({"tir.is_scheduled": 1})
+ # with T.block("root"):
var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local")
A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared")
var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), scope="shared")
for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
- for ax1_0 in T.thread_binding(T.int64(2), thread="blockIdx.x"):
+ for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.y"):
for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(2)):
+ for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
with T.block("matmul_init"):
v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3_init)
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
T.reads()
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in range(T.int64(1)):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.block("A_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ v1 = T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(A[v0, v1, v2])
T.writes(A_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
- for ax0_ax1_ax2_fused_0 in range(T.int64(4)):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
with T.block("var_decode_intermediate_reindex_shared"):
v0 = T.axis.spatial(T.int64(1), T.int64(0))
- v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(256) + ax0_ax1_ax2_fused_1 * T.int64(16) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ v1 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
T.reads(W[v2 // T.int64(8), v1], S[v2 // T.int64(32), v1])
T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
- T.block_attr({"buffer_dim_align": [[0, 1, 32, 2]]})
+ T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32", T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 % T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32", T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)), T.uint32(16))) + T.reinterpret("float32", T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1], T.uint32(16)), T.uint32(65535)), T.ui [...]
- for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(2)):
+ for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
with T.block("matmul_update"):
v0 = T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_1 * T.int64(16) + ax1_2 * T.int64(2) + ax1_3)
+ v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2], A_reindex_shared[T.int64(0), v1, v3], var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] + A_reindex_shared[T.int64(0), v1, v3] * var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
- for ax0_1, ax1, ax2 in T.grid(T.int64(1), T.int64(2), T.int64(4)):
- with T.block("var_matmul_intermediate_reindex_local"):
- v0 = T.axis.spatial(T.int64(1), ax0_1)
- v1 = T.axis.spatial(T.int64(32), ax1_0 * T.int64(16) + ax1_2 * T.int64(2) + ax1)
- v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2)
- T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
- T.writes(Out[T.int64(0), v1, v2])
- Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+ for ax0_1, ax1, ax2_0_1 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with T.block("var_matmul_intermediate_reindex_local"):
+ v0 = T.axis.spatial(T.int64(1), ax0_1)
+ v1 = T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) + ax2_0_1 * T.int64(2) + ax2_1_1)
+ T.reads(C[T.int64(0), v1, v2], var_matmul_intermediate_reindex_local[v0, v1, v2])
+ T.writes(Out[T.int64(0), v1, v2])
+ Out[T.int64(0), v1, v2] = C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
# fmt: on
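For reference, these rules are exercised the way the dlight tests drive them: wrap a rule in the default-schedule pass and apply it to an IRModule under a concrete GPU target. A hedged usage sketch; the toy module and target tag here are illustrative:

    import tvm
    from tvm import dlight as dl
    from tvm.script import tir as T
    from tvm.target import Target

    @tvm.script.ir_module
    class Before:
        @T.prim_func
        def main(A: T.Buffer((32, 4096), "float16"),
                 B: T.Buffer((4096, 4096), "float16"),
                 C: T.Buffer((32, 4096), "float16")):
            for i, j, k in T.grid(32, 4096, 4096):
                with T.block("matmul"):
                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                    with T.init():
                        C[vi, vj] = T.float16(0)
                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    with Target("nvidia/geforce-rtx-3070"):
        # PrimFuncs the rule cannot schedule are left untouched.
        After = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(Before)
    print(After.script())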