This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2032b44f42 [TOPI] Ensure vectorization of input padding in `arm_cpu` int8 conv2d interleaved schedule (#15710)
2032b44f42 is described below

commit 2032b44f427dec2028e0d2a98be1d5210ed3033b
Author: Andrei Hutu <[email protected]>
AuthorDate: Mon Sep 11 10:15:08 2023 +0100

    [TOPI] Ensure vectorization of input padding in `arm_cpu` int8 conv2d interleaved schedule (#15710)
    
    When padding the input data, the int8 conv2d interleaved schedule tries to split the `data_im2col` cols axis by a factor of 16 in order to then vectorize over those splits. However, the size of the axis is `n_size = KH x KW x IC` and the `Legalize` pass only pads the number of input channels up to a multiple of 8. Therefore, `n_size` is only guaranteed to be a multiple of 8, not 16.
    I modified the schedule to check whether a split factor of 16 is appropriate, otherwise use 8 instead, in order to ensure vectorization is performed in all cases.
---
 python/tvm/topi/arm_cpu/conv2d_gemm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index ea9026688e..6ef8efec9e 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -326,7 +326,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        n_size = data_im2col.shape[2]
+        if n_size % 16 == 0:
+            split_factor = 16
+        else:
+            split_factor = 8
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
         b_m_fused = s[data_im2col].fuse(b, m)
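
The split-factor choice in the hunk above can be checked with plain arithmetic, independent of TVM. The following is a minimal sketch, not the code from the commit: `pad_channels` and `choose_split_factor` are hypothetical helper names, and the padding of input channels to a multiple of 8 is taken from the commit message rather than from the `Legalize` pass itself.

```python
def pad_channels(ic: int, multiple: int = 8) -> int:
    """Round the input channel count up to the next multiple of `multiple`.

    Assumption: mirrors the padding the commit message attributes to Legalize.
    """
    return -(-ic // multiple) * multiple


def choose_split_factor(n_size: int) -> int:
    """Same selection rule as the patched schedule: 16 when it divides n_size, else 8."""
    return 16 if n_size % 16 == 0 else 8


def im2col_cols(kh: int, kw: int, ic: int) -> int:
    """Size of the data_im2col cols axis, n_size = KH x KW x IC (IC after padding)."""
    return kh * kw * pad_channels(ic)


# 3x3 kernel, 3 input channels padded to 8: n_size = 72, which 16 does not divide.
n_small = im2col_cols(3, 3, 3)
# 3x3 kernel, 16 input channels (already a multiple of 8): n_size = 144.
n_large = im2col_cols(3, 3, 16)

for n_size in (n_small, n_large):
    factor = choose_split_factor(n_size)
    # The chosen factor always divides n_size exactly, so the inner axis has a
    # fixed extent and no tail iterations are left over.
    print(n_size, factor, n_size % factor)
```

In the first case a fixed factor of 16 would not divide the axis evenly, which is the situation the commit message describes; falling back to 8 gives an exact split.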
