mkatanbaf commented on code in PR #13752:
URL: https://github.com/apache/tvm/pull/13752#discussion_r1106334576


##########
python/tvm/topi/arm_cpu/qnn.py:
##########
@@ -118,38 +133,89 @@ def _make_tscript_ptr(buffer, offset, length, 
dtype="int16"):
     )
 
 
+def _bias_ptr(bias, c):
+    return _make_tscript_ptr(bias, c, 1, dtype="int32")
+
+
+def _scale_ptr(scale, c):
+    return _make_tscript_ptr(scale, c, 1, dtype="int32")
+
+
 def _make_tscript_call(func_name, *args):
     return T.evaluate(T.call_extern(func_name, *args, dtype="int32"))
 
 
 def _make_conv2d_primfunc(
-    call_dimensions: Tuple,
-    buffer_shapes: Tuple[Tuple, Tuple, Tuple, Tuple, Tuple],
+    output_dimensions: Tuple[int, int, int, int],
+    buffer_shapes: Tuple,
     aligned_func: Tuple[str, str],
     offset_func: Tuple[str, str],
-    ptr_gens: Tuple,
-):
-    height, width, out_channels = call_dimensions
+    ptr_gens: Tuple[Callable, Callable],
+    output_layout: str = "NHWC",
+) -> tir.function.PrimFunc:
+    """Makes a TIR PrimFunc computing Conv2D using a call to tensordot.
+
+    Can be used to generate regular, depthwise, and grouped Conv2D operators 
by passing different
+    arguments and ptr_gen functions. However, it only works for Conv2D 
operators where the height
+    stride of the tensor is divisible by two.
+
+    Parameters
+    ----------
+    output_dimensions : Tuple[int, int, int, int]
+        A tuple containing the out_height, out_width, out_channels, and 
desired num_outputs values
+        in that order.
+
+    buffer_shapes: Tuple[tvm.ir.container.Array]
+        The shapes of the data, kernel, bias, scale, and output tensors, in 
that order. Each shape
+        should be a TVM Array.
+
+    aligned_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-aligned 
tensordot operator.
+
+    offset_func: Tuple[str, str]
+        A tuple containing the (name, C implementation) of a word-unaligned 
tensordot operator. Can
+        be a tuple of empty strings if the Conv2D in question does not need an 
unaligned operator.
+
+    ptr_gens: Tuple[Callable, Callable]
+        A tuple of two functions to generate data and kernel access pointers. 
They should take as
+        inputs the buffer, (y, x, c) indices, and an alignment offset. They 
should return a
+        T.tvm_access_ptr object which can be used in T.call_extern.
+
+    output_layout: str
+        The tensor layout that will be produced by the generated PrimFunc. 
Should be NHWC or NCHW.
+    """
+
+    out_height, out_width, out_channels, num_outputs = output_dimensions
     data_shape, kernel_shape, bias_shape, scale_shape, output_shape = 
buffer_shapes
     aligned_func_name, aligned_func_code = aligned_func
     offset_func_name, offset_func_code = offset_func
-    output_ptr, data_ptr, kernel_ptr = ptr_gens
+    data_ptr, kernel_ptr = ptr_gens
 
     # If the functions are identical, we can skip the second loop
     if aligned_func_name == offset_func_name:
         aligned_channels = out_channels
-        offset_channels = tvm.tir.const(0)
-        c_step = tvm.tir.const(1)
+        offset_channels = 0
+        c_step = const(1)
     else:
         aligned_channels = out_channels // 2
         offset_channels = out_channels // 2
-        c_step = tvm.tir.const(2)
-
-    def bias_ptr(bias, c):
-        return _make_tscript_ptr(bias, c, 1, dtype="int32")
-
-    def scale_ptr(scale, c):
-        return _make_tscript_ptr(scale, c, 1, dtype="int32")
+        c_step = const(2)

Review Comment:
   Do we need to consider cases where the `out_channels` is an odd number?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to