This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 209971a62e [DLIGHT][GPU] Improved gemv outer fallback schedule (#16973)
209971a62e is described below
commit 209971a62edf4a6ea6c628ef8399e45e926e727c
Author: krishnaraj36 <[email protected]>
AuthorDate: Tue May 21 14:24:53 2024 +0530
[DLIGHT][GPU] Improved gemv outer fallback schedule (#16973)
* [DLIGHT][GPU] Improved gemv outer fallback schedule
Improved the gemv outer fallback schedules. This improves a
few gemv kernels by 20%.
* Fix lint error
* Fix the gemv schedule params for dynamic vocab_size kernel
---
python/tvm/dlight/gpu/gemv.py | 39 ++++++++----
tests/python/dlight/test_gpu_gemv.py | 113 +++++++++++++++++++----------------
2 files changed, 91 insertions(+), 61 deletions(-)
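
For context on where this rule runs: dlight schedule rules such as GEMV are applied to
the PrimFuncs of a module through dlight's ApplyDefaultSchedule pass under a concrete
target. A minimal sketch of that flow (the toy gemv workload, its shapes, and the
Android-hosted OpenCL target are illustrative assumptions, not something this commit adds):

    import tvm
    from tvm import dlight as dl
    from tvm.script import tir as T

    @T.prim_func
    def gemv(A: T.Buffer((4096, 4096), "float16"),
             x: T.Buffer((4096,), "float16"),
             y: T.Buffer((4096,), "float16")):
        # y = A @ x, the kind of reduction the GEMV rule matches
        for i, k in T.grid(4096, 4096):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    y[vi] = T.float16(0)
                y[vi] = y[vi] + A[vi, vk] * x[vk]

    mod = tvm.IRModule({"main": gemv})
    # Illustrative Adreno-style target; the "android" host string is what selects
    # the OpenCL/Android branch of the rule.
    target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")
    with target:
        # GEMV handles matching reduction blocks; Fallback covers everything else.
        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV(), dl.gpu.Fallback())(mod)
    print(mod.script())
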
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index cbef6235c0..da6a4ef834 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -463,6 +463,8 @@ class GEMV(GPUScheduleRule):
TS, TR = 4, 64
else:
TS, TR = 16, 32
+ else:
+ TS, TR = 1, 64
elif target.kind.name == "metal":
# Note that the following tile size is tuned on M2 Ultra for 7B
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
@@ -476,6 +478,8 @@ class GEMV(GPUScheduleRule):
TS, TR = 4, 16
else:
TS, TR = 2, 64
+ else:
+ TS, TR = 1, 64
elif target.kind.name == "rocm":
VEC_C = 4
# TODO: set LOAD_V_SHARED = False for now
@@ -489,13 +493,15 @@ class GEMV(GPUScheduleRule):
TS, TR = 1, 128
else:
TS, TR = 8, 64
+ else:
+ TS, TR = 1, 64
elif target.kind.name == "opencl" and "android" in str(target.host):
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 8
LOAD_V_SHARED = False
LOAD_V_VEC = -1
UNROLL = 8
- TS, TR = 2, 64
+ TS, TR = 2, 32
elif target.kind.name == "vulkan":
VEC_C = 4
LOAD_V_SHARED = True
@@ -506,6 +512,8 @@ class GEMV(GPUScheduleRule):
TS, TR = 4, 32
else:
TS, TR = 16, 32
+ else:
+ TS, TR = 1, 64
elif target.kind.name == "opencl" and "mali" in str(target.attrs):
VEC_C = 8
LOAD_V_SHARED = False
@@ -519,9 +527,6 @@ class GEMV(GPUScheduleRule):
UNROLL = 64
TS, TR = 1, 64
- if not isinstance(len_S, int):
- TS, TR = 1, 64
-
while TS * TR > target.max_num_threads:
if TS > 1:
TS //= 2
@@ -709,7 +714,11 @@ class GEMV(GPUScheduleRule):
if not isinstance(len_r, int):
return None
- if isinstance(len_s, int) and len_s > 32000:
+ if not isinstance(len_s, int):
+ TS, TR = 256, 1
+ LOAD_V_SHARED = True
+
+ if isinstance(len_s, int) and len_s > 96000:
return None
_, TILE_R = (
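
This hunk is the fix mentioned in the commit message for the dynamic vocab_size kernel:
when the spatial extent len_s is not a compile-time int, the schedule now forces
TS, TR = 256, 1 and enables LOAD_V_SHARED, and the cutoff for giving up on very large
static extents moves from 32000 to 96000. For illustration, a hedged sketch of what such
a dynamic-extent gemv looks like in TVMScript (buffer names and the 4096 reduction
length are made up; only the symbolic output length v matters here):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def dyn_gemv(a: T.handle, x: T.handle, y: T.handle):
        v = T.int64()  # symbolic output length, e.g. a vocab_size
        A = T.match_buffer(a, (v, T.int64(4096)), "float16")
        X = T.match_buffer(x, (T.int64(4096),), "float16")
        Y = T.match_buffer(y, (v,), "float16")
        for i, k in T.grid(v, T.int64(4096)):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    Y[vi] = T.float16(0)
                Y[vi] = Y[vi] + A[vi, vk] * X[vk]

For a function like this, the spatial extent would be the tir.Var v rather than an int,
so isinstance(len_s, int) is False and the new branch is taken.
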
@@ -754,7 +763,8 @@ class GEMV(GPUScheduleRule):
len_s = get_extent(sch, s)
# The config is designed for Adreno
- tx_len = 64
+ LOAD_V_SHARED = 1
+ tx_len = 128
vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
inner_r = 4
@@ -768,16 +778,23 @@ class GEMV(GPUScheduleRule):
sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8)
sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)
- cache_v = sch.cache_read(block, vector_input_buffers[0], "local")
- sch.compute_at(cache_v, r1, preserve_unit_loops=True)
- sch.vectorize(sch.get_loops(cache_v)[-1])
+ if LOAD_V_SHARED:
+ V_shared = sch.cache_read(block, vector_input_buffers[0], storage_scope="shared")
+ sch.compute_at(V_shared, bx, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ _, tx, vec_r = sch.split(l, factors=[None, tx_len, 8], preserve_unit_iters=True)
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec_r)
sch.vectorize(vec)
# Schedule epilogue
if epilogue_info is not None:
- sch.reverse_compute_at(epilogue_info.block_rv, tx)
-
+ sch.reverse_compute_at(epilogue_info.block_rv, bx, preserve_unit_loops=True)
+ ts_tile_s = sch.get_loops(epilogue_info.block_rv)[-1]
+ ts, vec = sch.split(ts_tile_s, factors=[tx_len, vec_len], preserve_unit_iters=True)
+ sch.bind(ts, "threadIdx.x")
+ sch.vectorize(vec)
sch.set_scope(block, 0, "local")
sch.decompose_reduction(block, r0)
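
The LOAD_V_SHARED branch added above stages the vector operand in shared memory, lets the
whole thread block copy it cooperatively, and vectorizes the innermost lanes; the updated
test expectation below (lv1607_shared) encodes the same structure. A standalone sketch of
that scheduling pattern on a toy gemv; the 128-thread split and the 8-wide vector factor
are illustrative stand-ins for tx_len and the fixed vector width used in the diff:

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def gemv(A: T.Buffer((4096, 4096), "float16"),
             x: T.Buffer((4096,), "float16"),
             y: T.Buffer((4096,), "float16")):
        for i, k in T.grid(4096, 4096):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    y[vi] = T.float16(0)
                y[vi] = y[vi] + A[vi, vk] * x[vk]

    sch = tvm.tir.Schedule(gemv)
    blk = sch.get_block("gemv")
    i, k = sch.get_loops(blk)
    bx, tx = sch.split(i, factors=[None, 128])
    sch.bind(bx, "blockIdx.x")
    sch.bind(tx, "threadIdx.x")
    # Cache the vector input in shared memory (read buffer index 1 is x here, A is 0),
    # copy it with all 128 threads of the block and vectorize the innermost 8 lanes,
    # mirroring the cache_read / compute_at / split / bind / vectorize sequence above.
    x_shared = sch.cache_read(blk, 1, storage_scope="shared")
    sch.compute_at(x_shared, bx, preserve_unit_loops=True)
    copy = sch.get_loops(x_shared)[-1]
    _, ctx, vec = sch.split(copy, factors=[None, 128, 8], preserve_unit_iters=True)
    sch.bind(ctx, "threadIdx.x")
    sch.vectorize(vec)
    print(sch.mod.script())
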
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 4aae617654..0f7b6f45ae 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -1106,82 +1106,95 @@ def test_outer_reduction_adreno_dynamic():
p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v))
# with T.block("root"):
var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16", scope="local")
- var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), v), "float16", scope="local")
- var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(4), T.int64(1), T.int64(1), v), "float16", scope="local")
+ var_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(8), T.int64(1), T.int64(1), v), "float16", scope="local")
+ var_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1), v), "float16", scope="local")
lv613_local = T.alloc_buffer((T.int64(128), v), "float16", scope="local")
lv612_local = T.alloc_buffer((T.int64(512), v), "uint32", scope="local")
- for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(63)) // T.int64(64), thread="blockIdx.x"):
- for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
- for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+ lv1607_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="shared")
+ for u_fused_ax0_fused_fused_0 in T.thread_binding((v + T.int64(255)) // T.int64(256), thread="blockIdx.x"):
+ for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+ for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init in T.thread_binding(T.int64(1), thread="threadIdx.y"):
for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init in T.vectorized(T.int64(8)):
with T.block("matmul_rf_init"):
- vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
+ vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0_init * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1_init)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
T.reads()
T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
- for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
- for ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1 in T.grid(T.int64(32), T.int64(1)):
- for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
- with T.block("lv613_local"):
- v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 + ax0)
- v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
- T.reads(lv613[v0, v1])
- T.writes(lv613_local[v0, v1])
- lv613_local[v0, v1] = lv613[v0, v1]
- for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+ for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+ for ax1_0_fused_ax1_1_fused_0 in range(T.int64(128)):
+ for ax0, ax1, ax2_0, ax2_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+ for ax2_2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+ for ax2_3 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
+ for ax2_4 in T.vectorized(T.int64(4)):
+ with T.block("lv1607_shared"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ v2 = T.axis.spatial(T.int64(4096), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + (ax2_0 * T.int64(1024) + ax2_1 * T.int64(1024) + ax2_2 * T.int64(4) + ax2_3 * T.int64(4) + ax2_4))
+ T.where(((ax2_0 + ax2_1) * T.int64(256) + ax2_2 + ax2_3) * T.int64(4) + ax2_4 < T.int64(32))
+ T.reads(lv1607[v0, v1, v2])
+ T.writes(lv1607_shared[v0, v1, v2])
+ lv1607_shared[v0, v1, v2] = lv1607[v0, v1, v2]
+ for ax1_0_fused_ax1_1_fused_1 in range(T.int64(1)):
for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
- with T.block("lv612_local"):
- v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(16) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
- v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 + ax1)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
- T.reads(lv612[v0, v1])
- T.writes(lv612_local[v0, v1])
- lv612_local[v0, v1] = lv612[v0, v1]
- for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
- with T.block("matmul_rf_update"):
- vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(32), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1)
- vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + u_fused_ax0_fused_fused_1 < v)
- T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused % T.int64(8)], lv [...]
- T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
- var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(128) + vax1_0_fused_ax1_1_fused_1 * T.int64(128) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused // T.int64(8) * T.int64(32) + va [...]
- for ax2 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
- for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+ with T.block("lv613_local"):
+ v0 = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_0 + ax0)
+ v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+ T.reads(lv613[v0, v1])
+ T.writes(lv613_local[v0, v1])
+ lv613_local[v0, v1] = lv613[v0, v1]
+ for ax1_0_fused_ax1_1_fused_3 in range(T.int64(4)):
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+ with T.block("lv612_local"):
+ v0 = T.axis.spatial(T.int64(512), ax1_0_fused_ax1_1_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_3 + ax0)
+ v1 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 + ax1)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+ T.reads(lv612[v0, v1])
+ T.writes(lv612_local[v0, v1])
+ lv612_local[v0, v1] = lv612[v0, v1]
+ for ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 in T.vectorized(T.int64(8)):
+ with T.block("matmul_rf_update"):
+ vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused = T.axis.spatial(T.int64(8), ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + ax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1)
+ vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_1, vax1_0_fused_ax1_1_fused_3 = T.axis.remap("RRR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_1, ax1_0_fused_ax1_1_fused_3])
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + u_fused_ax0_fused_fused_1 < v)
+ T.reads(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0], lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused], lv612_local[vax1_0_fused_ax1_1_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1 * T.int64(4) + [...]
+ T.writes(var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0])
+ var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused, T.int64(0), T.int64(0), v0] + lv1607_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1 * T.int64(32) + vax1_0_fused_ax1_1_fused_3 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_a [...]
+ for ax2 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+ for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
with T.block("matmul_rf_init"):
- vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(4), ax0)
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+ vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.spatial(T.int64(1), ax0)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
T.reads()
T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
for ax1 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 8, "pragma_unroll_explicit": 1}):
with T.block("matmul_rf_update"):
vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1 = T.axis.remap("SR", [ax0, ax1])
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax2)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax2 < v)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax2)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax2 < v)
T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0], var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0])
T.writes(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] = var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 * T.int64(8) + vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_1, T.int64(0), T.int64(0), v0]
- for ax1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
- for ax0 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+ for ax1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
+ for ax0 in T.thread_binding(T.int64(1), thread="threadIdx.y"):
with T.block("matmul"):
- vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(4), ax0)
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax1)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + ax1 < v)
+ vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0 = T.axis.reduce(T.int64(1), ax0)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax1)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + ax1 < v)
T.reads(var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0])
T.writes(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
with T.init():
var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0)
var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_2_ax1_0_fused_ax1_1_fused_4_fused_0, T.int64(0), T.int64(0), v0]
- for ax0_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+ for ax0_fused_0 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
for ax0_fused_1 in range(T.int64(1)):
with T.block("compute"):
- v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(64) + ax0_fused_0 + ax0_fused_1)
- T.where(u_fused_ax0_fused_fused_0 * T.int64(64) + (ax0_fused_0 + ax0_fused_1) < v)
+ v0 = T.axis.spatial(v, u_fused_ax0_fused_fused_0 * T.int64(256) + ax0_fused_0 + ax0_fused_1)
+ T.where(u_fused_ax0_fused_fused_0 * T.int64(256) + (ax0_fused_0 + ax0_fused_1) < v)
T.reads(var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0])
p_output0_intermediate[T.int64(0), T.int64(0), v0] = T.Cast("float32", var_matmul_intermediate_local[T.int64(0), T.int64(0), v0])