nverke commented on code in PR #13301:
URL: https://github.com/apache/tvm/pull/13301#discussion_r1021841923


##########
tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py:
##########
@@ -487,6 +487,149 @@ def lowered_single_reduction_loop_with_block_predicate(
                     )
 
 
+@T.prim_func
+def single_reduction_loop_with_tensorize(
+    input_A: T.Buffer[(1, 64, 7, 7, 32), "uint8"],
+    input_B: T.Buffer[(16, 64, 1, 1, 8, 32, 4), "int8"],
+    output: T.Buffer[(1, 16, 7, 7, 32), "int32"],
+) -> None:
+    # body
+    # with T.block("root")
+    for i1, i2, i3, i4, i5 in T.grid(16, 4, 98, 2, 32):
+        with T.block("compute_o"):
+            n = T.axis.spatial(1, 0)
+            oc_chunk = T.axis.spatial(16, i1)
+            oh = T.axis.spatial(7, (i2 * 6272 + i3 * 64 + i4 * 32 + i5) // 3584)
+            ow = T.axis.spatial(7, (i2 * 6272 + i3 * 64 + i4 * 32 + i5) % 3584 // 512)
+            kh = T.axis.reduce(1, 0)
+            kw = T.axis.reduce(1, 0)
+            ic_outer = T.axis.reduce(64, (i2 * 6272 + i3 * 64 + i4 * 32 + i5) % 512 // 8)
+            ic_f_inner = T.axis.reduce(8, (i2 * 6272 + i3 * 64 + i4 * 32 + i5) % 8)
+            T.reads(
+                input_A[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4],
+                input_B[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+            )
+            T.writes(output[n, oc_chunk, oh, ow, 0:32])
+            with T.init():
+                for x in T.serial(32):
+                    with T.block("compute_init"):
+                        oc_block_i_init = T.axis.spatial(32, x)
+                        T.reads()
+                        T.writes(output[n, oc_chunk, oh, ow, oc_block_i_init])
+                        output[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+            with T.block("compute_o"):
+                T.reads(
+                    output[n, oc_chunk, oh, ow, 0:32],
+                    input_A[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4],
+                    input_B[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                )
+                T.writes(output[n, oc_chunk, oh, ow, 0:32])
+                A = T.match_buffer(
+                    input_A[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4],
+                    [4],
+                    dtype="uint8",
+                    offset_factor=1,
+                )
+                B = T.match_buffer(
+                    input_B[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                    [32, 4],
+                    dtype="int8",
+                    offset_factor=1,
+                )
+                C = T.match_buffer(
+                    output[n, oc_chunk, oh, ow, 0:32], [32], dtype="int32", offset_factor=1
+                )
+                A_u8x4: T.uint8x4 = A[0:4]
+                A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                B_i8x128 = B[0, 0:128]
+                B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                C[0:32] = T.call_llvm_pure_intrin(
+                    4217, T.uint32(3), C[0:32], T.broadcast(A_i32, 32), B_i32x32, dtype="int32x32"
+                )
+
+
+@T.prim_func
+def nested_reduction_loop_with_inner_match_buffers(
+    in0: T.Buffer[(4, 16), "int8"],
+    in1: T.Buffer[(4, 16), "int8"],
+    out: T.Buffer[(4, 4), "int32"],
+) -> None:
+    # body
+    # with T.block("root")
+    for y in T.serial(4):
+        with T.block("C"):
+            yi = T.axis.spatial(4, y)
+            T.reads(in0[yi, 0:16], in1[yi, 0:16])
+            T.writes(out[yi, 0:4])
+            for x in T.serial(4):
+                xr = T.axis.reduce(4, x)
+                with T.init():
+                    for i in T.serial(4):
+                        with T.block("C_init"):
+                            ii = T.axis.spatial(4, i)
+                            T.reads()
+                            T.writes(out[yi, ii])
+                            out[yi, ii] = 0
+                with T.block("C"):
+                    T.reads(
+                        out[yi, xr],
+                        in0[yi, yi * 4 + xr : yi * 4 + xr + 4],
+                        in1[yi, yi * 4 + xr : yi * 4 + xr + 4],
+                    )
+                    T.writes(out[yi, xr])
+                    A = T.match_buffer(
+                        in0[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1
+                    )
+                    B = T.match_buffer(
+                        in1[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1
+                    )
+                    C = T.match_buffer(out[yi, xr], [1], dtype="int32", offset_factor=1)
+                    A_i8x4: T.int8x4 = A[0:4]
+                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
+                    B_i8x4: T.int8x4 = B[0:4]
+                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
+                    C[0] = A_i32 + B_i32 + C[0]
+
+
+@T.prim_func
+def nested_reduction_loop_with_outer_match_buffers(
+    in0: T.Buffer[(4, 16), "int8"],
+    in1: T.Buffer[(4, 16), "int8"],
+    out: T.Buffer[(4, 4), "int32"],
+) -> None:
+    # body
+    # with T.block("root")
+    for y in T.serial(4):
+        with T.block("C"):
+            yi = T.axis.spatial(4, y)
+            T.reads(in0[yi, 0:16], in1[yi, 0:16])
+            T.writes(out[yi, 0:4])
+            A = T.match_buffer(in0[yi, 0:16], [16], dtype="int8", offset_factor=1)
+            B = T.match_buffer(in1[yi, 0:16], [16], dtype="int8", offset_factor=1)
+            C = T.match_buffer(out[yi, 0:4], [4], dtype="int32", offset_factor=1)
+            for x in T.serial(4):
+                xr = T.axis.reduce(4, x)
+                with T.init():
+                    for i in T.serial(4):
+                        with T.block("C_init"):
+                            ii = T.axis.spatial(4, i)
+                            T.reads()
+                            T.writes(out[yi, ii])
+                            out[yi, ii] = 0
+                with T.block("C"):

Review Comment:
   https://github.com/apache/tvm/pull/13373/files



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to