nverke commented on code in PR #13110:
URL: https://github.com/apache/tvm/pull/13110#discussion_r1044960977
##########
tests/python/contrib/test_hexagon/test_async_dma_pipeline.py:
##########
@@ -349,5 +449,313 @@ def test_loading_vtcm_for_vrmpy(
"async_dma_input": async_input_runtime,
"async_dma_output": async_output_runtime,
"async_dma_input_output": async_input_output_runtime,
+ "async_dma_multi_input_output":
async_multi_input_output_runtime,
+ "async_input_output_runtime_larger_buffers":
async_input_output_runtime_larger_buffers,
},
)
+
+
+# from tvm.script import tir as T
+@tvm.script.ir_module
+class ModulePipelined:
+ @T.prim_func
+ def main(
+ p0: T.Buffer[(1, 1, 230, 230, 4), "uint8"],
+ p1: T.Buffer[(2, 1, 7, 7, 1, 32, 4), "int8"],
+ T_cast: T.Buffer[(1, 2, 112, 112, 32), "int32"],
+ ) -> None:
+ # function attr dict
+ T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+ # body
+ # with T.block("root")
+ conv2d_NCHWc_int8 = T.alloc_buffer([1, 2, 112, 112, 32],
dtype="int32", scope="global.vtcm")
+ p0_global_vtcm = T.alloc_buffer([1, 1, 230, 230, 4], dtype="uint8",
scope="global.vtcm")
+ p1_global_vtcm = T.alloc_buffer([2, 1, 7, 7, 1, 32, 4], dtype="int8",
scope="global.vtcm")
+ for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4):
+ with T.block("p1_global.vtcm"):
+ v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+ "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+ )
+ T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
+ T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+ p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2,
v3, v4, v5, v6]
+ for po in T.serial(4):
+ for i in T.serial(55876):
+ with T.block("p0_global.vtcm"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(1, 0)
+ v2 = T.axis.spatial(230, po * 56 + i // 916)
+ v3 = T.axis.spatial(230, i % 916 // 4)
+ v4 = T.axis.spatial(4, i % 4)
+ T.reads(p0[v0, v1, v2, v3, v4])
+ T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+ p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+ for i in T.parallel(28):
+ for ii, iii, iiii in T.grid(2, 14, 8):
+ with T.block("conv2d_NCHWc_int8_o_init"):
+ n = T.axis.spatial(1, 0)
+ oc_chunk = T.axis.spatial(2, ii)
+ oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 +
iii)
+ ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + iiii)
+ oc_block_o = T.axis.spatial(1, 0)
+ T.reads()
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+ for i4_1 in T.vectorized(32):
+ with T.block("conv2d_NCHWc_int8_init"):
+ oc_block_i_init = T.axis.spatial(32, i4_1)
+ T.reads()
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh,
ow, oc_block_i_init])
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow,
oc_block_i_init] = 0
+ for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8):
+ with T.block("conv2d_NCHWc_int8_o_update"):
+ n = T.axis.spatial(1, 0)
+ oc_chunk = T.axis.spatial(2, i1_1)
+ oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 +
i2_2)
+ ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + i3_2)
+ oc_block_o = T.axis.spatial(1, 0)
+ kh = T.axis.reduce(7, i5_1)
+ kw = T.axis.reduce(7, i6_1)
+ ic_outer = T.axis.reduce(1, 0)
+ ic_f_inner = T.axis.reduce(1, 0)
+ ic_s_inner_o = T.axis.reduce(1, 0)
+ T.reads(
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+ p0_global_vtcm[
+ n,
+ ic_outer,
+ oh * 2 + kh,
+ ow * 2 + kw,
+ ic_f_inner * 4 : ic_f_inner * 4 + 4,
+ ],
+ p1_global_vtcm[oc_chunk, ic_outer, kh, kw,
ic_f_inner, 0:32, 0:4],
+ )
+ T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+ A = T.match_buffer(
+ p0_global_vtcm[
+ n,
+ ic_outer,
+ oh * 2 + kh,
+ ow * 2 + kw,
+ ic_f_inner * 4 : ic_f_inner * 4 + 4,
+ ],
+ [4],
+ dtype="uint8",
+ offset_factor=1,
+ scope="global.vtcm",
+ )
+ B = T.match_buffer(
+ p1_global_vtcm[oc_chunk, ic_outer, kh, kw,
ic_f_inner, 0:32, 0:4],
+ [32, 4],
+ dtype="int8",
+ offset_factor=1,
+ scope="global.vtcm",
+ )
+ C = T.match_buffer(
+ conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+ [32],
+ dtype="int32",
+ offset_factor=1,
+ scope="global.vtcm",
+ )
+ A_u8x4: T.uint8x4 = A[0:4]
+ A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+ B_i8x128 = B[0, 0:128]
+ B_i32x32: T.int32x32 = T.reinterpret(B_i8x128,
dtype="int32x32")
+ C[0:32] = T.call_llvm_pure_intrin(
+ 4217,
Review Comment:
Ahh, interesting. I thought that they were tied to each intrin; I will update
accordingly!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org