This is an automated email from the ASF dual-hosted git repository.
expye pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 66d3957d2c [Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)
66d3957d2c is described below
commit 66d3957d2ce692528d74bfd07e6eebac63afed7d
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jul 17 02:59:46 2023 -0700
[Unity][Dlight] Rule matmul avoiding blockIdx.z (#15333)
Prior to this PR, the matmul rule of dlight binds loops to `blockIdx.z`.
However, not every device supports this blockIdx dimension (for example,
WebGPU does not support `blockIdx.z`), which makes dlight fail to
apply and build.
Therefore, this PR fuses the `blockIdx.z` loop with the other `blockIdx`
loop.
---
python/tvm/dlight/gpu/matmul.py | 4 +-
tests/python/dlight/test_gpu_matmul.py | 225 ++++++++++++++++-----------------
2 files changed, 114 insertions(+), 115 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index be5e4b02d7..b9977d08b9 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -327,8 +327,8 @@ class Matmul(ScheduleRule):
bx, vx, tx, xi = sch.split(x, [None, vthread_x, block_size_x,
micro_size_x])
by, vy, ty, yi = sch.split(y, [None, vthread_y, block_size_y,
micro_size_y])
ko, ki = sch.split(k, factors=[None, micro_size_k])
- sch.reorder(bx, by, vy, vx, ty, tx, ko, ki, yi, xi)
- sch.bind(batch, "blockIdx.z")
+ sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
+ by = sch.fuse(batch, by)
sch.bind(bx, "blockIdx.x")
sch.bind(by, "blockIdx.y")
sch.bind(vy, "vthread.y")
diff --git a/tests/python/dlight/test_gpu_matmul.py
b/tests/python/dlight/test_gpu_matmul.py
index f3d9a7089d..318a3e833c 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -57,65 +57,64 @@ class TestMatmul(BaseBeforeAfter):
matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
inp0_reindex_pad_shared = T.alloc_buffer((T.int64(1), (m +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
inp1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096),
T.int64(4096)), scope="shared")
- for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(64),
thread="blockIdx.y"):
for ax1_0 in T.thread_binding((m + T.int64(31)) // T.int64(32),
thread="blockIdx.x"):
- for ax2_0 in T.thread_binding(T.int64(64),
thread="blockIdx.y"):
- for ax2_1 in T.thread_binding(T.int64(1),
thread="vthread.y"):
- for ax1_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
- for ax2_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
- for ax1_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in
T.grid(T.int64(4), T.int64(4)):
- with T.block("matmul_init"):
- v0 = T.axis.spatial(T.int64(1),
ax0)
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 *
T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
- v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
- T.reads()
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in
T.grid(T.int64(4), T.int64(4)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial((m + T.int64(31))
// T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2
* T.int64(4) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_3_init)
+ T.reads()
+
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
+ matmul_reindex_pad_local[T.int64(0),
v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(2)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("inp0_reindex_pad_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial((m
+ T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) +
(ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) +
ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp0[v0, v1,
v2])
+
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2],
T.float32(0))
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(4)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("inp1_reindex_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 =
T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) +
(ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) +
ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(inp1[v2, v1])
+
T.writes(inp1_reindex_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
+ for ax3_1, ax2_3, ax1_3 in
T.grid(T.int64(16), T.int64(4), T.int64(4)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 *
T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_3)
+ v3 = T.axis.reduce(T.int64(4096),
ax3_0 * T.int64(16) + ax3_1)
+
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2],
inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0),
v2, v3])
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
- for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in
range(T.int64(2)):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
- with
T.block("inp0_reindex_pad_shared"):
- v0 =
T.axis.spatial(T.int64(1), T.int64(0))
- v1 =
T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 *
T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 *
T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) //
T.int64(16))
- v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(inp0[v0,
v1, v2])
-
T.writes(inp0_reindex_pad_shared[v0, v1, v2])
-
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-
inp0_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2],
T.float32(0))
- for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in
range(T.int64(4)):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
- with
T.block("inp1_reindex_shared"):
- v0 =
T.axis.spatial(T.int64(1), T.int64(0))
- v1 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(inp1[v2,
v1])
-
T.writes(inp1_reindex_shared[v0, v1, v2])
-
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-
inp1_reindex_shared[v0, v1, v2] = inp1[v2, v1]
- for ax3_1, ax2_3, ax1_3 in
T.grid(T.int64(16), T.int64(4), T.int64(4)):
- with T.block("matmul_update"):
- v0 =
T.axis.spatial(T.int64(1), ax0)
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 *
T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
- v2 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2
* T.int64(4) + ax2_3)
- v3 =
T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
-
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2],
inp0_reindex_pad_shared[T.int64(0), v1, v3], inp1_reindex_shared[T.int64(0),
v2, v3])
-
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-
matmul_reindex_pad_local[T.int64(0), v1, v2] =
matmul_reindex_pad_local[T.int64(0), v1, v2] +
inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0),
v2, v3]
- for ax0_1, ax1, ax2_0_1 in
T.grid(T.int64(1), T.int64(4), T.int64(2)):
- for ax2_1_1 in
T.vectorized(T.int64(2)):
- with
T.block("matmul_reindex_pad_local"):
- v0 =
T.axis.spatial(T.int64(1), ax0_1)
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
- v2 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_0_1 * T.int64(2) + ax2_1_1)
-
T.reads(matmul_reindex_pad_local[v0, v1, v2])
- T.writes(matmul[T.int64(0),
v1, v2])
- if v1 < m:
- matmul[T.int64(0), v1, v2]
= matmul_reindex_pad_local[v0, v1, v2]
+
matmul_reindex_pad_local[T.int64(0), v1, v2] =
matmul_reindex_pad_local[T.int64(0), v1, v2] +
inp0_reindex_pad_shared[T.int64(0), v1, v3] * inp1_reindex_shared[T.int64(0),
v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(T.int64(1),
T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with
T.block("matmul_reindex_pad_local"):
+ v0 = T.axis.spatial(T.int64(1),
ax0)
+ v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 *
T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
+
T.reads(matmul_reindex_pad_local[v0, v1, v2])
+ T.writes(matmul[T.int64(0), v1,
v2])
+ if v1 < m:
+ matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
# fmt: on
@@ -146,6 +145,7 @@ class TestFusedMatmul(BaseBeforeAfter):
T.reads(C[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0,
v_ax1, v_ax2])
T.writes(Out[v_ax0, v_ax1, v_ax2])
Out[v_ax0, v_ax1, v_ax2] = C[v_ax0, v_ax1, v_ax2] +
var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
@T.prim_func
def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S:
T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1),
T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32),
T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32),
T.int64(4096)), "float32")):
T.func_attr({"tir.is_scheduled": 1})
@@ -153,64 +153,63 @@ class TestFusedMatmul(BaseBeforeAfter):
var_matmul_intermediate_reindex_local = T.alloc_buffer((T.int64(1),
T.int64(32), T.int64(4096)), scope="local")
A_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(32),
T.int64(4096)), scope="shared")
var_decode_intermediate_reindex_shared = T.alloc_buffer((T.int64(1),
T.int64(4096), T.int64(4096)), scope="shared")
- for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(64),
thread="blockIdx.y"):
for ax1_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
- for ax2_0 in T.thread_binding(T.int64(64),
thread="blockIdx.y"):
- for ax2_1 in T.thread_binding(T.int64(1),
thread="vthread.y"):
- for ax1_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
- for ax2_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
- for ax1_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
- for ax2_3_init, ax1_3_init in
T.grid(T.int64(4), T.int64(4)):
- with T.block("matmul_init"):
- v0 = T.axis.spatial(T.int64(1),
ax0)
- v1 = T.axis.spatial(T.int64(32),
ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
- v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
- T.reads()
+ for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+ for ax1_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
+ for ax2_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in
T.grid(T.int64(4), T.int64(4)):
+ with T.block("matmul_init"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial(T.int64(32), ax1_0
* T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_3_init)
+ T.reads()
+
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
+
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
+ for ax3_0 in range(T.int64(256)):
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(2)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("A_reindex_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 =
T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) +
ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) +
ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(A[v0, v1, v2])
+
T.writes(A_reindex_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+ A_reindex_shared[v0,
v1, v2] = A[v0, v1, v2]
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(T.int64(4)):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
+ with
T.block("var_decode_intermediate_reindex_shared"):
+ v0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v1 =
T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) +
(ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) +
ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+ v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+ T.reads(W[v2 //
T.int64(8), v1], S[v2 // T.int64(32), v1])
+
T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32",
T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 %
T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32",
T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)),
T.uint32(16))) + T.reinterpret("float32",
T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1],
T.uint32(16)), T.uint32(65535)), T.uint32(16)))
+ for ax3_1, ax2_3, ax1_3 in
T.grid(T.int64(16), T.int64(4), T.int64(4)):
+ with T.block("matmul_update"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial(T.int64(32),
ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_3)
+ v3 = T.axis.reduce(T.int64(4096),
ax3_0 * T.int64(16) + ax3_1)
+
T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2],
A_reindex_shared[T.int64(0), v1, v3],
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
-
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] = T.float32(0)
- for ax3_0 in range(T.int64(256)):
- for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in
range(T.int64(2)):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
- with
T.block("A_reindex_shared"):
- v0 =
T.axis.spatial(T.int64(1), T.int64(0))
- v1 =
T.axis.spatial(T.int64(32), (ax0_ax1_ax2_fused_0 * T.int64(32) +
ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) +
ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(A[v0, v1,
v2])
-
T.writes(A_reindex_shared[v0, v1, v2])
-
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-
A_reindex_shared[v0, v1, v2] = A[v0, v1, v2]
- for ax0_ax1_ax2_fused_0 in
T.thread_binding(T.int64(16), thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_1 in
T.thread_binding(T.int64(8), thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_2 in
range(T.int64(4)):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(T.int64(2)):
- with
T.block("var_decode_intermediate_reindex_shared"):
- v0 =
T.axis.spatial(T.int64(1), T.int64(0))
- v1 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
- v2 =
T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 *
T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 *
T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
- T.reads(W[v2 //
T.int64(8), v1], S[v2 // T.int64(32), v1])
-
T.writes(var_decode_intermediate_reindex_shared[v0, v1, v2])
-
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
-
var_decode_intermediate_reindex_shared[v0, v1, v2] = T.Cast("float32",
T.bitwise_and(T.shift_right(W[v2 // T.int64(8), v1], T.Cast("uint32", v2 %
T.int64(8) * T.int64(4))), T.uint32(15))) * T.reinterpret("float32",
T.shift_left(T.bitwise_and(S[v2 // T.int64(32), v1], T.uint32(65535)),
T.uint32(16))) + T.reinterpret("float32",
T.shift_left(T.bitwise_and(T.shift_right(S[v2 // T.int64(32), v1],
T.uint32(16)), T.uint32(65535)), T.ui [...]
- for ax3_1, ax2_3, ax1_3 in
T.grid(T.int64(16), T.int64(4), T.int64(4)):
- with T.block("matmul_update"):
- v0 =
T.axis.spatial(T.int64(1), ax0)
- v1 =
T.axis.spatial(T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 *
T.int64(4) + ax1_3)
- v2 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2
* T.int64(4) + ax2_3)
- v3 =
T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
-
T.reads(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2],
A_reindex_shared[T.int64(0), v1, v3],
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3])
-
T.writes(var_matmul_intermediate_reindex_local[T.int64(0), v1, v2])
-
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] =
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] +
A_reindex_shared[T.int64(0), v1, v3] *
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
- for ax0_1, ax1, ax2_0_1 in
T.grid(T.int64(1), T.int64(4), T.int64(2)):
- for ax2_1_1 in
T.vectorized(T.int64(2)):
- with
T.block("var_matmul_intermediate_reindex_local"):
- v0 =
T.axis.spatial(T.int64(1), ax0_1)
- v1 =
T.axis.spatial(T.int64(32), ax1_2 * T.int64(4) + ax1)
- v2 =
T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(4) +
ax2_0_1 * T.int64(2) + ax2_1_1)
- T.reads(C[T.int64(0), v1, v2],
var_matmul_intermediate_reindex_local[v0, v1, v2])
- T.writes(Out[T.int64(0), v1,
v2])
- Out[T.int64(0), v1, v2] =
C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
+
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] =
var_matmul_intermediate_reindex_local[T.int64(0), v1, v2] +
A_reindex_shared[T.int64(0), v1, v3] *
var_decode_intermediate_reindex_shared[T.int64(0), v2, v3]
+ for ax0, ax1, ax2_0 in T.grid(T.int64(1),
T.int64(4), T.int64(2)):
+ for ax2_1_1 in T.vectorized(T.int64(2)):
+ with
T.block("var_matmul_intermediate_reindex_local"):
+ v0 = T.axis.spatial(T.int64(1),
ax0)
+ v1 = T.axis.spatial(T.int64(32),
ax1_2 * T.int64(4) + ax1)
+ v2 = T.axis.spatial(T.int64(4096),
ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) +
ax2_1_1)
+ T.reads(C[T.int64(0), v1, v2],
var_matmul_intermediate_reindex_local[v0, v1, v2])
+ T.writes(Out[T.int64(0), v1, v2])
+ Out[T.int64(0), v1, v2] =
C[T.int64(0), v1, v2] + var_matmul_intermediate_reindex_local[v0, v1, v2]
# fmt: on