Wheest commented on a change in pull request #7050:
URL: https://github.com/apache/tvm/pull/7050#discussion_r538435168



##########
File path: python/tvm/topi/nn/conv2d_sparse.py
##########
@@ -0,0 +1,261 @@
+import tvm
+from tvm import te
+from tvm.topi import nn
+from tvm.topi.nn.util import get_pad_tuple
+from tvm.topi.util import get_const_tuple
+from tvm import autotvm
+from ..nn.conv2d import conv2d_infer_layout, _get_workload as 
_get_conv2d_workload
+from ..util import get_const_tuple, traverse_inline
+from tvm.topi.sparse import batch_csrmm, csrmm_default
+
+def _fallback_schedule(cfg, wkl):
+    HPAD, WPAD = wkl.hpad, wkl.wpad
+    HSTR, WSTR = wkl.hstride, wkl.wstride
+    out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
+
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype,
                        is_depthwise=False, layout='NCHW'):
    """Get default schedule config for the workload.

    Symbolic (dynamic) dimensions in ``data.shape`` are replaced by 1 so a
    concrete conv2d workload can be derived, from which a fallback schedule
    is filled into ``cfg``.

    ``is_depthwise`` is currently unused; it is kept for interface
    compatibility with the dense conv2d config helpers.
    """
    # Replace any symbolic dimension with a concrete 1 to get a static shape.
    static_data_shape = [
        1 if isinstance(dim, tvm.tir.Var) else dim
        for dim in get_const_tuple(data.shape)
    ]
    data = te.placeholder(static_data_shape, dtype=data.dtype)
    wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype,
                               layout)
    # NOTE(review): ``is_kernel_1x1`` is computed but never consumed — TODO
    # either use it when picking the fallback schedule or remove it.
    is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
    _fallback_schedule(cfg, wkl)
+
def conv2d_sparse_gemm_nchw(data, w_data, w_indices, w_indptr,
                            OC, KH, KW,
                            strides, padding, dilation,
                            out_dtype='float32'):
    """Compute conv2d with a sparse (CSR) weight matrix.

    The input is transformed with im2col, multiplied against the CSR weight
    matrix via ``batch_csrmm`` (GEMM), and the result is produced directly
    in NCHW layout (no separate output transform).

    Parameters
    ----------
    data : te.Tensor
        4-D input tensor in NCHW layout.
    w_data, w_indices, w_indptr : te.Tensor
        CSR representation of the (OC, IC*KH*KW) weight matrix.
    OC, KH, KW : int
        Number of output channels and kernel height/width.
    strides, padding, dilation : int or tuple of int
        Standard conv2d attributes.
    out_dtype : str
        Output dtype; currently unused here — presumably consumed by
        ``batch_csrmm`` (TODO confirm).

    Returns
    -------
    te.Tensor
        4-D output of shape (batch, OC, OH, OW).
    """
    batches, IC, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = \
        get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) \
        else (strides, strides)

    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1

    # GEMM dimensions: C[OC, M] = W[OC, K] x im2col(data)[K, M]
    K = KH * KW * IC
    M = OH * OW

    # BUGFIX: pad whenever ANY side needs padding — the original only
    # checked pad_top/pad_left, so bottom/right-only padding was dropped.
    if pad_top or pad_left or pad_down or pad_right:
        data_pad = nn.pad(data, [0, 0, pad_top, pad_left],
                          [0, 0, pad_down, pad_right],
                          name="data_pad")
    else:
        data_pad = data

    # --- Im2col
    # Flatten (IC, KH, KW) into K row-major: k = c*KH*KW + kh*KW + kw.
    # BUGFIX: the kernel-row offset is (k // KW) % KH — the original used
    # (k // KH) % KW, which is only correct for square kernels — and the
    # kernel offsets are scaled by the dilation (the output size above
    # already accounts for it).
    B_shape = (batches, K, M)
    B = te.compute(
        B_shape,
        lambda n, k, m:
        data_pad[n,
                 k // (KH * KW),
                 ((k // KW) % KH) * dilation_h + (m // OW) * HSTR,
                 (k % KW) * dilation_w + (m % OW) * WSTR],
        name='data_im2col')

    # --- GEMM: A*B' via one CSR multiply per batch, writing straight
    # into the NCHW output shape.
    oshape = (batches, OC, OH, OW)
    C = batch_csrmm(w_data, w_indices, w_indptr, B, oshape)

    return C
+
+def csrdc(data, indices, indptr, inputs, oshape, kdim, strides, padding):

Review comment:
       `csrdc` stands for CSR Direct Convolution, following the naming convention of `csrmm`
(CSR Matrix Multiply) in `python/tvm/topi/sparse/csrmm.py`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to