Wheest commented on a change in pull request #6137:
URL: https://github.com/apache/tvm/pull/6137#discussion_r591433973



##########
File path: topi/python/topi/arm_cpu/group_conv2d.py
##########
@@ -0,0 +1,310 @@
+import tvm
+from tvm import autotvm
+from tvm import te
+from ..util import get_const_tuple
+from ..nn.pad import pad
+from .. import tag
+
+from ..nn.conv2d import group_conv2d_nchw
+from ..nn.util import infer_pad
+from ..nn.conv2d import _get_workload as _get_conv2d_workload
+
+from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
+
+
+def group_conv2d_nchw(data, kernel, strides, padding, dilation, groups,
+                      out_dtype):
+    """Compute group_conv2d with NCHW layout"""
+    return group_conv2d_nchw_spatial_pack(data, kernel, strides, padding,
+                                          dilation, groups, out_dtype)
+
+
+def schedule_group_conv2d_nchw(outs):
+    """Compute group_conv2d with NCHW layout"""
+    return schedule_group_conv2d_nchwc(outs)
+
+
+def _get_default_config(cfg, data, kernel, strides, padding, groups, out_dtype,
+                        layout='NCHW'):
+    """
+    Get default schedule config for the workload
+    """
+    static_data_shape = []
+    for dim in get_const_tuple(data.shape):
+        if isinstance(dim, tvm.tir.Var):
+            static_data_shape.append(1)
+        else:
+            static_data_shape.append(dim)
+    data = te.placeholder(static_data_shape, dtype=data.dtype)
+
+    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype,
+                               layout)
+    _fallback_schedule(cfg, wkl)
+
+
+def _fallback_schedule(cfg, wkl):
+    simd_width = 4 # assume ARM SIMD Width is 4
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+    G = wkl.groups
+    KPG = wkl.out_filter // G
+    CPG = wkl.in_filter // G
+    oc_bn = 1
+
+    for bn in range(simd_width, 0, -1):
+        if KPG % bn == 0:
+            oc_bn = bn
+            break
+
+    ic_bn = 1
+    for bn in range(oc_bn, 0, -1):
+        if CPG % bn == 0:
+            ic_bn = bn
+            break
+
+    reg_n = 1
+    for n in range(31, 0, -1):
+        if out_width % n == 0:
+            reg_n = n
+            break
+
+    cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
+    cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
+    cfg["tile_ow"] = SplitEntity([out_width // reg_n, reg_n])
+    cfg["unroll_kw"] = OtherOptionEntity(False)
+
+
+@autotvm.register_topi_compute("group_conv2d_nchw.arm_cpu")
+def group_conv2d_nchw_spatial_pack(cfg, data, kernel, strides, padding,
+                                   dilation, groups, out_dtype='float32'):
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(dilation, int):
+        dilation_h, dilation_w = dilation, dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    assert isinstance(padding, int) or len(padding) == 2 or len(padding) == 4
+    if isinstance(padding, int):
+        HPAD, WPAD = padding, padding
+    elif len(padding) == 2:
+        HPAD, WPAD = padding
+    else:
+        HPAD, _, WPAD, _ = padding
+
+    assert isinstance(strides, int) or len(strides) == 2
+    if isinstance(strides, int):
+        HSTR, WSTR = strides, strides
+    else:
+        HSTR, WSTR = strides
+
+    N, CI, IH, IW = get_const_tuple(data.shape)
+    CO, CIG, KH, KW = get_const_tuple(kernel.shape)
+
+    pad_height = IH + 2 * HPAD
+    pad_width = IW + 2 * WPAD
+
+    dilated_kernel_h = (KH - 1) * dilation_h + 1
+    dilated_kernel_w = (KW - 1) * dilation_w + 1
+    OH = (IH + 2 * HPAD - dilated_kernel_h) // HSTR + 1
+    OW = (IW + 2 * WPAD - dilated_kernel_w) // WSTR + 1
+
+    G = groups
+    KPG = CO // G
+    CPG = CI // G
+
+    cfg.define_split("tile_ic", CI, num_outputs=2)
+    cfg.define_split("tile_oc", CO, num_outputs=2)
+    cfg.define_split("tile_ow", OW, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
+    cfg.define_knob("unroll_kw", [True, False])
+
+    # If no config was set, we can fallback to default config.
+    if cfg.is_fallback:
+        _get_default_config(cfg, te.placeholder((N, CI, IH, IW), dtype=data.dtype),
+                            te.placeholder((N, CI // G, KH, KW),
+                                           dtype=kernel.dtype),
+                            strides, padding, groups, out_dtype)
+
+    oc_bn = cfg['tile_oc'].size[-1]
+    ic_bn = cfg['tile_ic'].size[-1]
+    # pack data
+    DOPAD = (HPAD != 0 or WPAD != 0)
+    if DOPAD:
+        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
+    else:
+        data_pad = data
+
+    shape = (G, N, CPG // ic_bn,
+             pad_height, ic_bn, pad_width)
+
+    data_vec = te.compute(shape,
+                          lambda g, n, C, h, c, w:
+                          data_pad[n, C * ic_bn + c + CPG * g, h, w],
+                          name='data_vec')
+
+    # pack kernel
+    shape = (G, KPG//oc_bn, CPG//ic_bn,
+             KH, KW, ic_bn, oc_bn)
+    kernel_vec = te.compute(shape,

Review comment:
       If this issue cannot be easily resolved, can the pull request be merged
without `alter_op_layout`, and a new PR be made that tracks the issue and
includes my attempts thus far at solving it?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to