lhutton1 commented on code in PR #15648:
URL: https://github.com/apache/tvm/pull/15648#discussion_r1315667771


##########
python/tvm/topi/arm_cpu/conv2d_gemm.py:
##########
@@ -428,12 +435,29 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        if A.op.name == "A_padded_K":
+            s[data_im2col].compute_at(s[A], A.op.axis[1])
+            s[A].parallel(A.op.axis[1])
+        elif A.op.name == "A_padded_M":
+            s[data_im2col].parallel(m)
+            s[A].parallel(A.op.axis[1])
+        else:
+            s[data_im2col].parallel(m)
+
+        split_factor = 16
+        n_size = data_im2col.shape[2]
+        if n_size % split_factor != 0:
+            # Split by kernel area (KH * KW) to ensure proper vectorization
+            ic = data_im2col.op.input_tensors[0].shape[3]
+            split_factor = n_size // ic
+
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)
         s[data_im2col].unroll(n_outer)
         s[data_im2col].vectorize(n_inner)
-        s[data_im2col].parallel(m)
     elif padding_A:
         s[data_im2col].compute_inline()
+        _, n_inner = s[A].split(A.op.axis[2], 16)

Review Comment:
   For the 16 here, let's make it a variable to maintain the link to 
https://github.com/apache/tvm/pull/15648/files#diff-42b1313a1be464c7f2c94f75d656be725f1ccb54b9391cd5b27c33009ac0e2d5R421
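
One way the shared value could look, sketched in plain Python rather than TVM schedule code. The name `K_SPLIT_FACTOR` and both helpers are hypothetical, and the idea that the linked line pads to a multiple of the same 16 is an assumption, not something visible in this hunk:

```python
# Hypothetical module-level constant; the PR does not define this name.
K_SPLIT_FACTOR = 16


def pad_to_multiple(extent, factor=K_SPLIT_FACTOR):
    """Round extent up to the next multiple of factor (assumed padding rule)."""
    return ((extent + factor - 1) // factor) * factor


def split_extents(extent, factor=K_SPLIT_FACTOR):
    """Return (outer, inner) loop extents for splitting an axis by factor."""
    outer = (extent + factor - 1) // factor
    return outer, factor
```

With both sites reading the same constant, changing the factor in one place cannot silently leave the other at 16.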



##########
python/tvm/topi/arm_cpu/conv2d_gemm.py:
##########
@@ -326,7 +328,12 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        n_size = data_im2col.shape[2]
+        if n_size % 16 == 0:
+            split_factor = 16
+        else:
+            split_factor = 8
+        n_outer, n_inner = s[data_im2col].split(n, split_factor)

Review Comment:
   nit: as it's a separate change, could it be split into a new PR?



##########
python/tvm/topi/arm_cpu/conv2d_gemm.py:
##########
@@ -428,12 +435,29 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
 
     b, m, n = data_im2col.op.axis
     if data_im2col.op.name == "data_im2col":
-        n_outer, n_inner = s[data_im2col].split(n, 16)
+        if A.op.name == "A_padded_K":

Review Comment:
   It would be good to add a comment here stating that if both 
"padding_K" and "padding_M" ops are present, this path will be selected.

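A plain-Python sketch of where such a comment could sit. The function `im2col_parallel_plan` and its string return values are hypothetical stand-ins for the schedule calls in the hunk; the reading that A is named "A_padded_K" whenever both paddings are applied follows the review comment and is otherwise an assumption:

```python
def im2col_parallel_plan(a_op_name):
    """Describe which branch of the native schedule handles data_im2col."""
    # If both K and M padding are applied, A is assumed to be named
    # "A_padded_K", so this first branch is the one selected.
    if a_op_name == "A_padded_K":
        return "compute_at A, parallelize A"
    elif a_op_name == "A_padded_M":
        return "parallelize data_im2col and A"
    # No padding op: only data_im2col is parallelized.
    return "parallelize data_im2col"
```
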


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
