nverke commented on code in PR #13110:
URL: https://github.com/apache/tvm/pull/13110#discussion_r1044960977


##########
tests/python/contrib/test_hexagon/test_async_dma_pipeline.py:
##########
@@ -349,5 +449,313 @@ def test_loading_vtcm_for_vrmpy(
                 "async_dma_input": async_input_runtime,
                 "async_dma_output": async_output_runtime,
                 "async_dma_input_output": async_input_output_runtime,
+                "async_dma_multi_input_output": 
async_multi_input_output_runtime,
+                "async_input_output_runtime_larger_buffers": 
async_input_output_runtime_larger_buffers,
             },
         )
+
+
+# from tvm.script import tir as T
+@tvm.script.ir_module
+class ModulePipelined:
+    @T.prim_func
+    def main(
+        p0: T.Buffer[(1, 1, 230, 230, 4), "uint8"],
+        p1: T.Buffer[(2, 1, 7, 7, 1, 32, 4), "int8"],
+        T_cast: T.Buffer[(1, 2, 112, 112, 32), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+        # body
+        # with T.block("root")
+        conv2d_NCHWc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], 
dtype="int32", scope="global.vtcm")
+        p0_global_vtcm = T.alloc_buffer([1, 1, 230, 230, 4], dtype="uint8", 
scope="global.vtcm")
+        p1_global_vtcm = T.alloc_buffer([2, 1, 7, 7, 1, 32, 4], dtype="int8", 
scope="global.vtcm")
+        for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4):
+            with T.block("p1_global.vtcm"):
+                v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+                    "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+                )
+                T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
+                T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+                p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, 
v3, v4, v5, v6]
+        for po in T.serial(4):
+            for i in T.serial(55876):
+                with T.block("p0_global.vtcm"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1 = T.axis.spatial(1, 0)
+                    v2 = T.axis.spatial(230, po * 56 + i // 916)
+                    v3 = T.axis.spatial(230, i % 916 // 4)
+                    v4 = T.axis.spatial(4, i % 4)
+                    T.reads(p0[v0, v1, v2, v3, v4])
+                    T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+                    p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+            for i in T.parallel(28):
+                for ii, iii, iiii in T.grid(2, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_init"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, ii)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + 
iii)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + iiii)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        T.reads()
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        for i4_1 in T.vectorized(32):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                oc_block_i_init = T.axis.spatial(32, i4_1)
+                                T.reads()
+                                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, 
ow, oc_block_i_init])
+                                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 
oc_block_i_init] = 0
+                for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_update"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, i1_1)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + 
i2_2)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + i3_2)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        kh = T.axis.reduce(7, i5_1)
+                        kw = T.axis.reduce(7, i6_1)
+                        ic_outer = T.axis.reduce(1, 0)
+                        ic_f_inner = T.axis.reduce(1, 0)
+                        ic_s_inner_o = T.axis.reduce(1, 0)
+                        T.reads(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, 
ic_f_inner, 0:32, 0:4],
+                        )
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        A = T.match_buffer(
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            [4],
+                            dtype="uint8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        B = T.match_buffer(
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, 
ic_f_inner, 0:32, 0:4],
+                            [32, 4],
+                            dtype="int8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        C = T.match_buffer(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            [32],
+                            dtype="int32",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        A_u8x4: T.uint8x4 = A[0:4]
+                        A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                        B_i8x128 = B[0, 0:128]
+                        B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, 
dtype="int32x32")
+                        C[0:32] = T.call_llvm_pure_intrin(
+                            4217,

Review Comment:
   Ahh, interesting! I thought that they were tied to each intrin. Will update
accordingly!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to