This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7d98f60144 [Unity][Dlight] Fix matmul schedule when out_dtype = fp32 and bias add is fp32 (#15363)
7d98f60144 is described below

commit 7d98f601442c9dd1ad02355704a0ac3ce09d35db
Author: masahi <[email protected]>
AuthorDate: Thu Jul 20 16:39:27 2023 +0900

    [Unity][Dlight] Fix matmul schedule when out_dtype = fp32 and bias add is fp32 (#15363)
    
    * workaround for unscheduled cast block
    
    * clean test
    
    * comment
---
 python/tvm/dlight/gpu/matmul.py        |  17 +++++
 tests/python/dlight/test_gpu_matmul.py | 129 +++++++++++++++++++++++++++++++++
 2 files changed, 146 insertions(+)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index b9977d08b9..def13a60ac 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -365,5 +365,22 @@ class Matmul(ScheduleRule):
         auto_inline_producers(sch, a_g2s)
         auto_inline_producers(sch, b_g2s)
         auto_inline_consumers(sch, l2g)
+
+        remaining_consumers = sch.get_consumers(l2g)
+
+        if len(remaining_consumers) != 0:
+            # Some blocks have failed to be inlined to the producer cache-write stage.
+            # This could be due to another producer block that has not been scheduled.
+            for c in remaining_consumers:
+                for p in sch.get_producers(c):
+                    if sch.get(p) != sch.get(l2g):
+                        sch.compute_inline(p)
+
+            # Try inlining into the cache-write stage again, this time it should succeed.
+            auto_inline_consumers(sch, l2g)
+
+        msg = "There are some consumers of the cache-write stage that are not 
properly inlined."
+        assert len(sch.get_consumers(l2g)) == 0, msg
+
         sch.decompose_reduction(main_block, ko)
         return sch
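
Context for readers applying this patch: the Matmul rule modified above is normally run through the dlight default-schedule pass. The following is a minimal sketch, not part of this commit; it assumes the unity-branch helper tvm.dlight.ApplyDefaultSchedule alongside tvm.dlight.gpu.Matmul, and the "nvidia/geforce-rtx-3090-ti" target tag used elsewhere in the dlight tests. The module, function, and buffer names are illustrative only. It shows a matmul whose out_dtype is fp32, the kind of workload this fix targets; the fp32 bias-add epilogue itself is exercised by the new TestOutputFP32 test in the diff that follows.

from tvm import dlight as dl
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target


# Illustrative module: fp16 inputs accumulated into an fp32 output (out_dtype = fp32).
@I.ir_module
class Module:
    @T.prim_func
    def matmul(
        A: T.Buffer((1, 32, 4096), "float16"),
        B: T.Buffer((4096, 4096), "float16"),
        C: T.Buffer((1, 32, 4096), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True)})
        for i0, i1, i2, k in T.grid(1, 32, 4096, 4096):
            with T.block("matmul"):
                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
                with T.init():
                    C[v_i0, v_i1, v_i2] = T.float32(0)
                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_k, v_i2])


# Apply the Matmul schedule rule under a CUDA target; rules that do not match
# the workload leave the function untouched.
with Target("nvidia/geforce-rtx-3090-ti"):
    scheduled = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(Module)
print(scheduled.script())
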
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 318a3e833c..b9ee95b76b 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -19,6 +19,7 @@ import pytest
 
 import tvm.testing
 from tvm import dlight as dl
+from tvm.script import ir as I
 from tvm.script import tir as T
 from tvm.target import Target
 
@@ -247,5 +248,133 @@ class TestSkipGEMV(BaseBeforeAfter):
     expected = before
 
 
+class TestOutputFP32(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
+        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
+        # with T.block("root"):
+        p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
+        var_compute_intermediate = T.alloc_buffer((T.int64(4096),))
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), n, T.int64(4096)))
+        var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), n, T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv13[v_i, v_j // T.int64(8)], lv14[v_i, v_j // T.int64(32)])
+                T.writes(p_output0_intermediate_1[v_i, v_j])
+                p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v_i, v_j // T.int64(32)]
+        for i0, i1, i2, k in T.grid(T.int64(1), n, T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv48[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv48[v_i0, v_i1, v_k]) * T.Cast("float32", p_output0_intermediate_1[v_k, v_i2])
+        for i0 in range(T.int64(4096)):
+            with T.block("compute"):
+                v_i0 = T.axis.spatial(T.int64(4096), i0)
+                T.reads(lv13_1[v_i0])
+                T.writes(var_compute_intermediate[v_i0])
+                var_compute_intermediate[v_i0] = T.Cast("float32", lv13_1[v_i0])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], var_compute_intermediate[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + var_compute_intermediate[v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("compute_1"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2])
+                var_compute_intermediate_1[v_i0, v_i1, v_i2] = T.Cast("float16", var_T_add_intermediate[v_i0, v_i1, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(4096)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], lv3[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        n = T.int64()
+        lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16")
+        lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16")
+        p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(4096)), "float16")
+        # with T.block("root"):
+        var_matmul_intermediate_reindex_pad_local = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+        lv48_reindex_pad_shared = T.alloc_buffer((T.int64(1), (n + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared")
+        p_output0_intermediate_1_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((n + T.int64(31)) // T.int64(32), thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(T.int64(4), T.int64(4)):
+                                    with T.block("matmul_init"):
+                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                        v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_init)
+                                        v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init)
+                                        T.reads()
+                                        T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                        var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
+                                for ax3_0 in range(T.int64(256)):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(2)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("lv48_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv48[v0, v1, v2])
+                                                        T.writes(lv48_reindex_pad_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        lv48_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < n, lv48[v0, v1, v2], T.float16(0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(T.int64(4)):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)):
+                                                    with T.block("p_output0_intermediate_1_reindex_shared"):
+                                                        v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                                        v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16))
+                                                        v2 = T.axis.spatial(T.int64(4096), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16))
+                                                        T.reads(lv13[v2, v1 // T.int64(8)], lv14[v2, v1 // T.int64(32)])
+                                                        T.writes(p_output0_intermediate_1_reindex_shared[v0, v1, v2])
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        p_output0_intermediate_1_reindex_shared[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv13[v2, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv14[v2, v1 // T.int64(32)]
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(T.int64(16), T.int64(4), T.int64(4)):
+                                        with T.block("matmul_update"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3)
+                                            v3 = T.axis.reduce(T.int64(4096), ax3_0 * T.int64(16) + ax3_1)
+                                            T.reads(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2], lv48_reindex_pad_shared[T.int64(0), v1, v3], p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
+                                            T.writes(var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2])
+                                            var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] = var_matmul_intermediate_reindex_pad_local[T.int64(0), v1, v2] + T.Cast("float32", lv48_reindex_pad_shared[T.int64(0), v1, v3]) * T.Cast("float32", p_output0_intermediate_1_reindex_shared[T.int64(0), v2, v3])
+                                for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)):
+                                    for ax2_1_1 in T.vectorized(T.int64(2)):
+                                        with T.block("var_matmul_intermediate_reindex_pad_local"):
+                                            v0 = T.axis.spatial(T.int64(1), ax0)
+                                            v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
+                                            v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
+                                            T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
+                                            T.writes(p_output0_intermediate[T.int64(0), v1, v2])
+                                            if v1 < n:
+                                                p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
+
+    # fmt: on
+
+
 if __name__ == "__main__":
     tvm.testing.main()
