This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7d98f60144 [Unity][Dlight] Fix matmul schedule when out_dtype = fp32 and bias add is fp32 (#15363)
7d98f60144 is described below
commit 7d98f601442c9dd1ad02355704a0ac3ce09d35db
Author: masahi <[email protected]>
AuthorDate: Thu Jul 20 16:39:27 2023 +0900
[Unity][Dlight] Fix matmul schedule when out_dtype = fp32 and bias add is fp32 (#15363)
* workaround for unscheduled cast block
* clean test
* comment
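
As a quick illustration (not part of the patch), the fixed rule would typically be exercised as in the sketch below; `mod` is a placeholder IRModule containing an fp32 matmul followed by an fp32 bias add (like the new test's `before` function), and the target tag is only an example:

    import tvm
    from tvm import dlight as dl
    from tvm.target import Target

    # Hedged sketch: apply the dlight Matmul schedule rule under a CUDA target.
    # `mod` is a placeholder IRModule; any CUDA target tag would do here.
    with Target("nvidia/geforce-rtx-3090-ti"):
        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)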
---
python/tvm/dlight/gpu/matmul.py | 17 +++++
tests/python/dlight/test_gpu_matmul.py | 129 +++++++++++++++++++++++++++++++++
2 files changed, 146 insertions(+)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index b9977d08b9..def13a60ac 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -365,5 +365,22 @@ class Matmul(ScheduleRule):
         auto_inline_producers(sch, a_g2s)
         auto_inline_producers(sch, b_g2s)
         auto_inline_consumers(sch, l2g)
+
+        remaining_consumers = sch.get_consumers(l2g)
+
+        if len(remaining_consumers) != 0:
+            # Some blocks have failed to be inlined to the producer cache-write stage.
+            # This could be due to another producer block that has not been scheduled.
+            for c in remaining_consumers:
+                for p in sch.get_producers(c):
+                    if sch.get(p) != sch.get(l2g):
+                        sch.compute_inline(p)
+
+            # Try inlining into the cache-write stage again, this time it should succeed.
+            auto_inline_consumers(sch, l2g)
+
+            msg = "There are some consumers of the cache-write stage that are not properly inlined."
+            assert len(sch.get_consumers(l2g)) == 0, msg
+
         sch.decompose_reduction(main_block, ko)
         return sch
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 318a3e833c..b9ee95b76b 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,6 +19,7 @@ import pytest
import tvm.testing
from tvm import dlight as dl
+from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
@@ -247,5 +248,133 @@ class TestSkipGEMV(BaseBeforeAfter):
expected = before
+class TestOutputFP32(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
+        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
+        # with T.block("root"):
+        p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
+        var_compute_intermediate = T.alloc_buffer((T.int64(4096),))
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
+        var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv13[v_i, v_j // T.int64(8)], lv14[v_i, v_j // T.int64(32)])
+                T.writes(p_output0_intermediate_1[v_i, v_j])
+                p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v_i, v_j // T.int64(32)]
+        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv48[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv48[v_i0, v_i1, v_k]) * T.Cast("float32", p_output0_intermediate_1[v_k, v_i2])
+        for i0 in range(T.int64(4096)):
+            with T.block("compute"):
+                v_i0 = T.axis.spatial(T.int64(4096), i0)
+                T.reads(lv13_1[v_i0])
+                T.writes(var_compute_intermediate[v_i0])
+                var_compute_intermediate[v_i0] = T.Cast("float32", lv13_1[v_i0])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], var_compute_intermediate[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + var_compute_intermediate[v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("compute_1"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2])
+                var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv3[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
+        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
+        # with T.block("root"):
+        var_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+        lv48_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
+        p_output0_intermediate_1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+                                    with T.block("matmul_init"):
+                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                        v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+                                        v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                        T.reads()
+                                        T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                        var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+                                for ax3_0 in range(T.int64(256)):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("lv48_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv48[v0, v1, v2])
+                                                        T.writes(lv48_reindex_pad_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        lv48_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv48[v0, v1, v2], T.float16(0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("p_output0_intermediate_1_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv13[v2, v1 // T.int64(8)], lv14[v2, v1 // T.int64(32)])
+                                                        T.writes(p_output0_intermediate_1_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        p_output0_intermediate_1_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v2, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v2, v1 // T.int64(32)]
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+                                        with T.block("matmul_update"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                            T.reads(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv48_reindex_pad_shared[T.int64(0), v1, v3], p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
+                                            T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                            var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", lv48_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
+                                for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                    for ax2_1_1 in T.vectorized(T.int64(2)):
+                                        with T.block("var_matmul_intermediate_reindex_pad_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+                                            T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
+                                            T.writes(p_output0_intermediate[T.int64(0), v1, v2])
+                                            if v1 < n:
+                                                p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
+
+    # fmt: on
+
+
if __name__ == "__main__":
tvm.testing.main()