lhutton1 commented on code in PR #15648: URL: https://github.com/apache/tvm/pull/15648#discussion_r1315667771
########## python/tvm/topi/arm_cpu/conv2d_gemm.py: ########## @@ -428,12 +435,29 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): b, m, n = data_im2col.op.axis if data_im2col.op.name == "data_im2col": - n_outer, n_inner = s[data_im2col].split(n, 16) + if A.op.name == "A_padded_K": + s[data_im2col].compute_at(s[A], A.op.axis[1]) + s[A].parallel(A.op.axis[1]) + elif A.op.name == "A_padded_M": + s[data_im2col].parallel(m) + s[A].parallel(A.op.axis[1]) + else: + s[data_im2col].parallel(m) + + split_factor = 16 + n_size = data_im2col.shape[2] + if n_size % split_factor != 0: + # Split by kernel area (KH * KW) to ensure proper vectorization + ic = data_im2col.op.input_tensors[0].shape[3] + split_factor = n_size // ic + + n_outer, n_inner = s[data_im2col].split(n, split_factor) s[data_im2col].unroll(n_outer) s[data_im2col].vectorize(n_inner) - s[data_im2col].parallel(m) elif padding_A: s[data_im2col].compute_inline() + _, n_inner = s[A].split(A.op.axis[2], 16) Review Comment: For the 16 here, let's make it a variable to maintain the link to https://github.com/apache/tvm/pull/15648/files#diff-42b1313a1be464c7f2c94f75d656be725f1ccb54b9391cd5b27c33009ac0e2d5R421 ########## python/tvm/topi/arm_cpu/conv2d_gemm.py: ########## @@ -326,7 +328,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out): b, m, n = data_im2col.op.axis if data_im2col.op.name == "data_im2col": - n_outer, n_inner = s[data_im2col].split(n, 16) + n_size = data_im2col.shape[2] + if n_size % 16 == 0: + split_factor = 16 + else: + split_factor = 8 + n_outer, n_inner = s[data_im2col].split(n, split_factor) Review Comment: nit: as it's a separate change, could it be split into a new PR? 
########## python/tvm/topi/arm_cpu/conv2d_gemm.py: ########## @@ -428,12 +435,29 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): b, m, n = data_im2col.op.axis if data_im2col.op.name == "data_im2col": - n_outer, n_inner = s[data_im2col].split(n, 16) + if A.op.name == "A_padded_K": Review Comment: It would be good to add a comment here stating that, if there are both "padding_K" and "padding_M" ops, this path will be selected. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org