junrushao commented on code in PR #15141:
URL: https://github.com/apache/tvm/pull/15141#discussion_r1240091398


##########
tests/python/dlight/test_modules.py:
##########
@@ -0,0 +1,2621 @@
+from tvm.script import ir as I
+from tvm.script import tir as T
+
+# fmt: off
+
+@I.ir_module
+class Decode:
+
+    @T.prim_func
+    def func1(A: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), B: 
T.Buffer((T.int64(80), T.int64(2560)), "float16"), T_transpose: 
T.Buffer((T.int64(2560), T.int64(2560)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(2560), T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(2560), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(2560), T.int64(2560)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @T.prim_func
+    def func2(A: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), B: 
T.Buffer((T.int64(80), T.int64(10240)), "float16"), T_transpose: 
T.Buffer((T.int64(10240), T.int64(2560)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(2560), T.int64(10240)), "float16")
+        for i, j in T.grid(T.int64(2560), T.int64(10240)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(10240), T.int64(2560)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @T.prim_func
+    def func3(A: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), B: 
T.Buffer((T.int64(320), T.int64(2560)), "float16"), T_transpose: 
T.Buffer((T.int64(2560), T.int64(10240)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(10240), T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(10240), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(2560), T.int64(10240)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @T.prim_func
+    def func4(A: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), B: 
T.Buffer((T.int64(128), T.int64(4096)), "float16"), T_transpose: 
T.Buffer((T.int64(4096), T.int64(4096)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @T.prim_func
+    def func5(A: T.Buffer((T.int64(512), T.int64(11008)), "uint32"), B: 
T.Buffer((T.int64(128), T.int64(11008)), "float16"), T_transpose: 
T.Buffer((T.int64(11008), T.int64(4096)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(11008)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(11008), T.int64(4096)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+    @T.prim_func
+    def func6(A: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), B: 
T.Buffer((T.int64(344), T.int64(4096)), "float16"), T_transpose: 
T.Buffer((T.int64(4096), T.int64(11008)), "float16")):
+        T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode = T.alloc_buffer((T.int64(11008), T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(11008), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(A[v_i // T.int64(8), v_j], B[v_i // T.int64(32), v_j])
+                T.writes(decode[v_i, v_j])
+                decode[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(A[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % 
T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * B[v_i // 
T.int64(32), v_j]
+        for ax0, ax1 in T.grid(T.int64(4096), T.int64(11008)):
+            with T.block("T_transpose"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(decode[v_ax1, v_ax0])
+                T.writes(T_transpose[v_ax0, v_ax1])
+                T_transpose[v_ax0, v_ax1] = decode[v_ax1, v_ax0]
+
+
+@I.ir_module
+class DecodeGemv:
+
+    @T.prim_func
+    def func1(lv2515: T.Buffer((T.int64(320), T.int64(50432)), "uint32"), 
lv2516: T.Buffer((T.int64(80), T.int64(50432)), "float32"), lv705: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float32"), 
var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(50432)), 
"float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(2560), 
T.int64(50432)))
+        for i, j in T.grid(T.int64(2560), T.int64(50432)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv2515[v_i // T.int64(8), v_j], lv2516[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = T.Cast("float32", 
T.Cast("float16", T.bitwise_and(T.shift_right(lv2515[v_i // T.int64(8), v_j], 
T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - 
T.float16(7)) * lv2516[v_i // T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(50432), 
T.int64(2560)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv705[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv705[v_i0, v_i1, v_k] * 
var_decode_intermediate[v_k, v_i2]
+
+    @T.prim_func
+    def func2(lv1363: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), 
lv1364: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv2067: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias192: 
T.Buffer((T.int64(2560),), "float16"), p_output0_intermediate: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(2560), 
T.int64(2560)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(2560), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv1363[v_i // T.int64(8), v_j], lv1364[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1363[v_i // T.int64(8), v_j], T.Cast("uint32", 
v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1364[v_i 
// T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), 
T.int64(2560)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv2067[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv2067[v_i0, v_i1, v_k] * 
var_decode_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias192[v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias192[v_ax2]
+
+    @T.prim_func
+    def func3(lv1381: T.Buffer((T.int64(320), T.int64(2560)), "uint32"), 
lv1382: T.Buffer((T.int64(80), T.int64(2560)), "float16"), lv328: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias195: 
T.Buffer((T.int64(2560),), "float16"), lv2062: T.Buffer((T.int64(1), 
T.int64(1), T.int64(2560)), "float16"), p_output0_intermediate: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(2560), 
T.int64(2560)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(2560), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv1381[v_i // T.int64(8), v_j], lv1382[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1381[v_i // T.int64(8), v_j], T.Cast("uint32", 
v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1382[v_i 
// T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), 
T.int64(2560)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv328[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + lv328[v_i0, v_i1, v_k] * 
var_decode_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias195[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias195[v_ax2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], 
lv2062[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = 
var_T_add_intermediate[v_ax0, v_ax1, v_ax2] + lv2062[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def func4(lv1387: T.Buffer((T.int64(320), T.int64(10240)), "uint32"), 
lv1388: T.Buffer((T.int64(80), T.int64(10240)), "float16"), lv2115: 
T.Buffer((T.int64(1), T.int64(1), T.int64(2560)), "float16"), linear_bias196: 
T.Buffer((T.int64(10240),), "float32"), p_output0_intermediate: 
T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(2560), 
T.int64(10240)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(10240)))
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(10240)))
+        T_multiply = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240)))
+        compute = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240)))
+        T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240)))
+        T_add = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(10240)))
+        var_T_multiply_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(10240)))
+        for i, j in T.grid(T.int64(2560), T.int64(10240)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv1387[v_i // T.int64(8), v_j], lv1388[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1387[v_i // T.int64(8), v_j], T.Cast("uint32", 
v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1388[v_i 
// T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(10240), 
T.int64(2560)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv2115[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2115[v_i0, 
v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias196[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias196[v_ax2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("T_multiply"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                T_multiply[v_ax0, v_ax1, v_ax2] = 
var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T.float32(0.70710678118654757)
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("compute"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(T_multiply[v_i0, v_i1, v_i2])
+                T.writes(compute[v_i0, v_i1, v_i2])
+                compute[v_i0, v_i1, v_i2] = T.erf(T_multiply[v_i0, v_i1, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("T_multiply_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(compute[v_ax0, v_ax1, v_ax2])
+                T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2])
+                T_multiply_1[v_ax0, v_ax1, v_ax2] = compute[v_ax0, v_ax1, 
v_ax2] * T.float32(0.5)
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2])
+                T.writes(T_add[v_ax0, v_ax1, v_ax2])
+                T_add[v_ax0, v_ax1, v_ax2] = T.float32(0.5) + 
T_multiply_1[v_ax0, v_ax1, v_ax2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("T_multiply_2"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_T_add_intermediate[v_ax0, v_ax1, v_ax2], 
T_add[v_ax0, v_ax1, v_ax2])
+                T.writes(var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_multiply_intermediate[v_ax0, v_ax1, v_ax2] = 
var_T_add_intermediate[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(10240)):
+            with T.block("compute_1"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_T_multiply_intermediate[v_i0, v_i1, v_i2])
+                T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
+                p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", 
var_T_multiply_intermediate[v_i0, v_i1, v_i2])
+
+    @T.prim_func
+    def func5(lv1393: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), 
lv1394: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv2121: 
T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias197: 
T.Buffer((T.int64(2560),), "float32"), lv329: T.Buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), 
T.int64(1), T.int64(2560)), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(10240), 
T.int64(2560)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)))
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)))
+        var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(10240), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv1393[v_i // T.int64(8), v_j], lv1394[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv1393[v_i // T.int64(8), v_j], T.Cast("uint32", 
v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv1394[v_i 
// T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), 
T.int64(10240)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv2121[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv2121[v_i0, 
v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias197[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias197[v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("compute"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2])
+                var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", 
var_T_add_intermediate[v_i0, v_i1, v_i2])
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("compute_1"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_compute_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2])
+                var_compute_intermediate_1[v_i0, v_i1, v_i2] = 
var_compute_intermediate[v_i0, v_i1, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], 
lv329[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = 
var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv329[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func
+    def func6(lv2509: T.Buffer((T.int64(1280), T.int64(2560)), "uint32"), 
lv2510: T.Buffer((T.int64(320), T.int64(2560)), "float16"), lv4105: 
T.Buffer((T.int64(1), T.int64(1), T.int64(10240)), "float16"), linear_bias383: 
T.Buffer((T.int64(2560),), "float32"), lv701: T.Buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), 
T.int64(1), T.int64(2560)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_decode_intermediate = T.alloc_buffer((T.int64(10240), 
T.int64(2560)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)))
+        var_T_add_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)))
+        var_compute_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        var_compute_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        var_T_add_intermediate_1 = T.alloc_buffer((T.int64(1), T.int64(1), 
T.int64(2560)), "float16")
+        for i, j in T.grid(T.int64(10240), T.int64(2560)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                T.reads(lv2509[v_i // T.int64(8), v_j], lv2510[v_i // 
T.int64(32), v_j])
+                T.writes(var_decode_intermediate[v_i, v_j])
+                var_decode_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(lv2509[v_i // T.int64(8), v_j], T.Cast("uint32", 
v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv2510[v_i 
// T.int64(32), v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(2560), 
T.int64(10240)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv4105[v_i0, v_i1, v_k], var_decode_intermediate[v_k, 
v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = 
var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", lv4105[v_i0, 
v_i1, v_k]) * T.Cast("float32", var_decode_intermediate[v_k, v_i2])
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_matmul_intermediate[v_ax0, v_ax1, v_ax2], 
linear_bias383[v_ax2])
+                T.writes(var_T_add_intermediate[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate[v_ax0, v_ax1, v_ax2] = 
var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + linear_bias383[v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("compute"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_T_add_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate[v_i0, v_i1, v_i2])
+                var_compute_intermediate[v_i0, v_i1, v_i2] = T.Cast("float16", 
var_T_add_intermediate[v_i0, v_i1, v_i2])
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("compute_1"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                T.reads(var_compute_intermediate[v_i0, v_i1, v_i2])
+                T.writes(var_compute_intermediate_1[v_i0, v_i1, v_i2])
+                var_compute_intermediate_1[v_i0, v_i1, v_i2] = 
var_compute_intermediate[v_i0, v_i1, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("T_add_1"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(var_compute_intermediate_1[v_ax0, v_ax1, v_ax2], 
lv701[v_ax0, v_ax1, v_ax2])
+                T.writes(var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2])
+                var_T_add_intermediate_1[v_ax0, v_ax1, v_ax2] = 
var_compute_intermediate_1[v_ax0, v_ax1, v_ax2] + lv701[v_ax0, v_ax1, v_ax2]
+        for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), T.int64(2560)):
+            with T.block("compute_2"):
+                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])

Review Comment:
   will do



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to