wsl-inspur commented on a change in pull request #5485:
URL: https://github.com/apache/incubator-tvm/pull/5485#discussion_r417902315



##########
File path: topi/python/topi/cuda/conv2d_nhwc_winograd.py
##########
@@ -0,0 +1,639 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument
+# pylint: disable=too-many-arguments,too-many-locals
+# pylint: disable=too-many-statements
+"""Winograd template for cuda backend"""
+
+import tvm
+from tvm import te
+from tvm import autotvm
+from .. import nn
+from ..util import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+from .tensor_intrin import intrin_wmma_load_matrix_A
+from .tensor_intrin import intrin_wmma_load_matrix_W
+from .tensor_intrin import intrin_wmma_store_matrix
+from .tensor_intrin import intrin_wmma_gemm
+
+def _infer_tile_size(data, kernel):
+    """Compute the tile size"""
+    N, H, W, CI = get_const_tuple(data.shape)
+    if H % 8 == 0:
+        return 4
+    return 2
+
+
+def schedule_bgemm_tensorcore(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm tensorcore
+
+    Schedules the batched GEMM ``bgemm`` (C = A x B, with A = ``data_pack``
+    and B = ``kernel_pack``) onto Tensor Cores via WMMA tensor intrinsics.
+    Both operands are staged global -> shared -> wmma fragments, and the
+    wmma accumulator is written back to C through a shared-memory stage.
+
+    Parameters
+    ----------
+    cfg : AutoTVM config
+        Knobs for warp tiling, shared-memory row offsets, vector width and
+        the WMMA fragment shape ("wmma_m") are defined on and read from it.
+    s : Schedule
+        The schedule to modify in place; nothing is returned.
+    bgemm : Tensor
+        Batched GEMM output; its last two axes are (P, out_dim).
+    data_pack : Tensor
+        Left operand A, loaded as float16 matrix_a fragments.
+    kernel_pack : Tensor
+        Right operand B, loaded as float16 matrix_b fragments.
+    """
+    A = data_pack
+    B = kernel_pack
+    C = bgemm
+    _, _, P, out_dim = get_const_tuple(C.shape)
+    out_dtype = C.dtype
+
+    # Explicit memory access: global -> shared -> wmma fragments for A and B
+    AS = s.cache_read(A, 'shared', [C])
+    BS = s.cache_read(B, 'shared', [C])
+    AF = s.cache_read(AS, 'wmma.matrix_a', [C])
+    BF = s.cache_read(BS, 'wmma.matrix_b', [C])
+    CF = s.cache_write(C, 'wmma.accumulator')  # accumulate in wmma fragments
+    CS = s.cache_read(CF, 'shared', [C])  # stage the accumulator in shared before writing C
+
+    # Create tuning space
+    cfg.define_knob("block_row_warps", [1, 2, 4])  # warps per block along P
+    cfg.define_knob("block_col_warps", [1, 2, 4])  # warps per block along out_dim
+    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8])  # wmma_m fragments per warp along P
+    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8])  # wmma_n fragments per warp along out_dim
+    cfg.define_knob("chunk", [1, 2, 4, 8])  # wmma_k tiles per shared-memory reduction step
+    cfg.define_knob("offset", [0, 1, 2, 4, 8])  # extra elements added to AS/BS row length
+    cfg.define_knob("offsetCS", [0, 1, 2, 4, 8])  # extra elements added to CS row length
+    cfg.define_knob("vec", [1, 2, 4, 8])  # vector width of the shared loads and the C write
+
+    # Ensure that the default parameters are applicable when autotvm is not in 
use
+    if (P % 16 == 0 and out_dim % 16 == 0):  # each list starts with a value valid for this shape
+        cfg.define_knob("wmma_m", [16, 8, 32])
+    elif (P % 32 == 0 and out_dim % 8 == 0):
+        cfg.define_knob("wmma_m", [32, 16, 8])
+    elif (P % 8 == 0 and out_dim % 32 == 0):
+        cfg.define_knob("wmma_m", [8, 16, 32])
+
+    warp_size = 32  # threads per CUDA warp
+    wmma_k = 16  # K extent of one WMMA fragment
+    block_row_warps = cfg["block_row_warps"].val
+    block_col_warps = cfg["block_col_warps"].val
+    warp_row_tiles = cfg["warp_row_tiles"].val
+    warp_col_tiles = cfg["warp_col_tiles"].val
+    chunk = cfg["chunk"].val
+    offsetAB = cfg["offset"].val
+    offsetCS = cfg["offsetCS"].val
+    wmma_m = cfg["wmma_m"].val  # NOTE(review): "wmma_m" is only defined above when one of the three shape checks passes
+    vec = cfg["vec"].val
+
+    if wmma_m == 16:  # pair wmma_m with its wmma_n: (16, 16), (8, 32) or (32, 8)
+        wmma_n = 16
+    elif wmma_m == 8:
+        wmma_n = 32
+    elif wmma_m == 32:
+        wmma_n = 8
+
+    # Define the stride of intrin functions (row lengths in elements, including offsets)
+    AS_align = chunk * wmma_k + offsetAB
+    BS_align = warp_col_tiles * block_col_warps * wmma_n + offsetAB
+    CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS
+    AS_stride = [AS_align, 1]
+    BS_stride = [BS_align, 1]
+    AF_stride = [wmma_k, 1]
+    BF_stride = [wmma_n * warp_col_tiles, 1]
+    CF_stride = [warp_col_tiles * wmma_n, 1]
+    CS_stride = [CS_align, 1]
+    block_x = te.thread_axis('blockIdx.x')
+    block_y = te.thread_axis('blockIdx.y')
+    block_z = te.thread_axis('blockIdx.z')
+    thread_x = te.thread_axis('threadIdx.x')
+    thread_y = te.thread_axis('threadIdx.y')
+    thread_z = te.thread_axis('threadIdx.z')
+
+    # Schedule for computation: alpha axes -> blockIdx.z, (P, out_dim) tiles -> blockIdx.x/y
+    block_factor_b = wmma_m * warp_row_tiles * block_row_warps  # P rows per block
+    block_factor_o = wmma_n * warp_col_tiles * block_col_warps  # out_dim columns per block
+    alpha_1, alpha_2, b, o = C.op.axis
+    block_k = s[C].fuse(alpha_1, alpha_2)
+    block_i, bc = s[C].split(b, factor=block_factor_b)
+    block_j, oc = s[C].split(o, factor=block_factor_o)
+    s[C].reorder(block_k, block_i, block_j, bc, oc)
+    t = s[C].fuse(bc, oc)  # spread the block tile over vec lanes, then threadIdx.x/y/z
+    t, vi = s[C].split(t, factor=vec)
+    t, tx = s[C].split(t, factor=warp_size)
+    t, ty = s[C].split(t, factor=block_row_warps)
+    t, tz = s[C].split(t, factor=block_col_warps)
+    s[C].bind(block_k, block_z)
+    s[C].bind(block_i, block_x)
+    s[C].bind(block_j, block_y)
+    s[C].bind(tz, thread_z)
+    s[C].bind(ty, thread_y)
+    s[C].bind(tx, thread_x)
+    s[C].vectorize(vi)
+
+    # Schedule for wmma store
+    s[CS].compute_at(s[C], block_j)
+    _, _, bb, oo = CS.op.axis
+    s[CS].storage_align(bb, CS_align - 1, CS_align)
+    bb, bbi = s[CS].split(bb, factor=wmma_m)  # bbi x ooi is one fragment, tensorized below
+    oo, ooi = s[CS].split(oo, factor=wmma_n)
+    bb, bbii = s[CS].split(bb, factor=warp_row_tiles)
+    oo, ooii = s[CS].split(oo, factor=warp_col_tiles)
+    s[CS].reorder(bb, oo, bbii, ooii, bbi, ooi)
+
+    # Schedule for wmma computation
+    s[CF].compute_at(s[CS], oo)
+    _, _, warp_i, warp_j = CF.op.axis
+    warp_i, _ii = s[CF].split(warp_i, factor=wmma_m)
+    warp_j, _jj = s[CF].split(warp_j, factor=wmma_n)
+    k, = CF.op.reduce_axis
+    k, _k = s[CF].split(k, factor=wmma_k)  # _k: inside one WMMA fragment
+    ko, ki = s[CF].split(k, factor=chunk)  # ko: shared-memory step of chunk fragments
+    s[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k)
+
+    # Schedule for wmma_matrix_a load
+    s[AF].compute_at(s[CF], ki)
+    _, _, b, i = AF.op.axis
+    b, b_ii = s[AF].split(b, factor=wmma_m)
+    i, i_jj = s[AF].split(i, factor=wmma_k)
+    s[AF].reorder(b, i, b_ii, i_jj)
+
+    # Schedule for wmma_matrix_b load
+    s[BF].compute_at(s[CF], ki)
+    _, _, i, o = BF.op.axis
+    o, o_ii = s[BF].split(o, factor=wmma_n)
+    i, i_ii = s[BF].split(i, factor=wmma_k)
+    s[BF].reorder(i, o, i_ii, o_ii)
+
+    # Schedule for A's(B's) shared memory load
+    def shared_shedule(stage, strides):  # "strides" is the row alignment (AS_align / BS_align)
+        s[stage].compute_at(s[CF], ko)
+        _, _, xo, yo = stage.op.axis
+        s[stage].storage_align(xo, strides - 1, strides)
+        t = s[stage].fuse(xo, yo)
+        t, vi = s[stage].split(t, factor=vec)
+        t, tx = s[stage].split(t, factor=warp_size)
+        t, ty = s[stage].split(t, factor=block_row_warps)
+        _, tz = s[stage].split(t, factor=block_col_warps)
+        s[stage].bind(ty, thread_y)
+        s[stage].bind(tz, thread_z)
+        s[stage].bind(tx, thread_x)
+        s[stage].vectorize(vi)
+
+    shared_shedule(AS, AS_align)
+    shared_shedule(BS, BS_align)
+
+    shape = (wmma_m, wmma_n, wmma_k)
+    in_dtype = 'float16'  # Tensor Core operands are fp16
+    AL_gemm = te.placeholder((wmma_m, wmma_k), name='AL_gemm', dtype=in_dtype)  # one-fragment GEMM description for intrin_wmma_gemm
+    BL_gemm = te.placeholder((wmma_k, wmma_n), name='BL_gemm', dtype=in_dtype)
+    k_gemm = te.reduce_axis((0, wmma_k), name='k_gemm')
+    CL_compute = te.compute((wmma_m, wmma_n), lambda ii, jj:
+                            te.sum(AL_gemm[ii, k_gemm].astype(out_dtype) *
+                                   BL_gemm[k_gemm, jj].astype(out_dtype),
+                                   axis=k_gemm), name='CL_compute')
+
+    # Lower the computation loops down to TensorCore hardware intrinsics
+    # by mapping the tensorcore to tensor intrinsics
+    s[AF].tensorize(b_ii, intrin_wmma_load_matrix_A(AF_stride, AS_stride, 
shape, "row_major",
+                                                    (wmma_m, wmma_k), (wmma_m, 
wmma_k), 'float16'))
+    s[BF].tensorize(i_ii, intrin_wmma_load_matrix_W(BF_stride, BS_stride, 
shape, "row_major",
+                                                    (wmma_k, wmma_n), (wmma_k, 
wmma_n), 'float16'))
+    s[CF].tensorize(_ii, intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, 
AF_stride,
+                                          BF_stride, CF_stride, shape))
+    s[CS].tensorize(bbi, intrin_wmma_store_matrix(CS_stride, CF_stride, shape, 
out_dtype,
+                                                  (wmma_m, wmma_n), (wmma_m, 
wmma_n)))
+
+
+def schedule_bgemm_direct(cfg, s, bgemm, data_pack, kernel_pack):
+    """Schedule for bgemm direct
+
+    Schedules the batched GEMM ``bgemm`` without Tensor Core intrinsics.
+    Output tiles are bound to blocks, virtual threads and threads, partial
+    results accumulate in a local cache stage, and both operands are staged
+    through shared memory once per outer reduction step.
+
+    Parameters
+    ----------
+    cfg : AutoTVM config
+        The splits "tile_b", "tile_y", "tile_x", "tile_rc" and the knobs
+        "offset_bgemm", "vector_bgemm" are defined on and read from it.
+    s : Schedule
+        The schedule to modify in place; nothing is returned.
+    bgemm : Tensor
+        4-D batched GEMM output with a single reduction axis.
+    data_pack : Tensor
+        Operand cached into shared memory as ``BB``.
+    kernel_pack : Tensor
+        Operand cached into shared memory as ``AA``.
+    """
+    b1, b2, y, x = s[bgemm].op.axis
+    rc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)  # extent of the first output axis
+
+    # Create tuning space
+    cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,  # fused b1*b2 axis
+                     filter=lambda x: x.size[-3:] == [1, 1, 1])  # only the outermost factor may exceed 1
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_knob("offset_bgemm", [0, 1, 2, 4, 8])  # extra elements added to AA/BB row length
+    cfg.define_knob("vector_bgemm", [1, 2, 4, 8])  # vector width of the shared-memory loads
+    offset_bgemm = cfg["offset_bgemm"].val
+    vector_bgemm = cfg["vector_bgemm"].val
+
+    C = bgemm
+    A0, B0 = kernel_pack, data_pack
+
+    # Designate the memory hierarchy
+    OL = s.cache_write(C, 'local')
+    AA = s.cache_read(A0, 'shared', [OL])
+    BB = s.cache_read(B0, 'shared', [OL])
+
+    # Tile and bind spatial axes
+    b = s[bgemm].fuse(b1, b2)
+    bgemm_scope, b = s[bgemm].split(b, nparts=1)  # extra outermost loop of extent 1
+    bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    s[C].bind(bz, te.thread_axis("blockIdx.z"))
+    s[C].bind(by, te.thread_axis("blockIdx.y"))
+    s[C].bind(bx, te.thread_axis("blockIdx.x"))
+    s[C].bind(vz, te.thread_axis("vthread"))
+    s[C].bind(vy, te.thread_axis("vthread"))
+    s[C].bind(vx, te.thread_axis("vthread"))
+    s[C].bind(tz, te.thread_axis("threadIdx.z"))
+    s[C].bind(ty, te.thread_axis("threadIdx.y"))
+    s[C].bind(tx, te.thread_axis("threadIdx.x"))
+    s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+
+    # Tile reduction axes
+    s[OL].compute_at(s[C], tx)  # each thread accumulates its own tile locally
+    b1, b2, y, x = s[OL].op.axis
+    b = s[OL].fuse(b1, b2)
+    rc, = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    s[OL].reorder(rco, b, y, x, rci)
+
+    s[AA].compute_at(s[OL], rco)  # refill AA once per outer reduction step
+    _, _, k, n = s[AA].op.axis  # AA row length: per-block x extent of tile_x plus offset_bgemm
+    AA_align = offset_bgemm + cfg["tile_x"].size[1] * cfg["tile_x"].size[2] * 
cfg["tile_x"].size[3]
+    s[AA].storage_align(k, AA_align - 1, AA_align)
+
+    s[BB].compute_at(s[OL], rco)  # refill BB once per outer reduction step
+    _, _, m, k = s[BB].op.axis
+    BB_align = offset_bgemm + cfg["tile_rc"].size[1]  # inner reduction chunk plus offset_bgemm
+    s[BB].storage_align(m, BB_align - 1, BB_align)
+
+    # Schedule for A and B shared memory load
+    for load in [AA, BB]:  # loads reuse C's thread extents (size[2] of each split)
+        fused = s[load].fuse(*list(s[load].op.axis))
+        fused, ti = s[load].split(fused, factor=vector_bgemm)
+        fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
+        fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
+        fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+        s[load].vectorize(ti)
+
+
+def nhwc_winograd_cuda(cfg, data, kernel, strides, padding, dilation, 
out_dtype,
+                       use_tensorcore, pre_computed):
+    """Compute declaration for winograd"""
+    tile_size = _infer_tile_size(data, kernel)
+    N, H, W, CI = get_const_tuple(data.shape)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    if not pre_computed:  # Kernel tensor is raw tensor, do strict check
+        if dilation_h != 1 or dilation_w != 1:
+            kernel = nn.dilate(kernel, (dilation_h, dilation_w, 1, 1))
+        KH, KW, CI, CO = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        # Kernel tensor is pre-transfomred. This op is created by 
conv2d_alter_op.
+        # Dilation is not supported
+        alpha, _, CI, CO = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+    P = N * nH * nW
+
+    # Determine whether the shape is available with tensorcore
+    shape_judge = (P % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                      (P % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                      (P % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
+
+    if shape_judge and use_tensorcore:
+        trans_type = "float16"
+    else:
+        trans_type = data.dtype
+
+    # Compute transform matrix
+    A, _, _ = winograd_transform_matrices(m, r, out_dtype)
+    _, B, G = winograd_transform_matrices(m, r, data.dtype)
+
+    # Transform kernel
+    if not pre_computed:
+        # Check if we are currently tuning, if so we want to avoid counting
+        # prepacking in time costs. Just use a placeholder with the packed 
shape instead.
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kernel_pack = te.placeholder((alpha, alpha, CI, CO),
+                                         dtype=kernel.dtype,
+                                         name='kernel_pack')
+        else:
+            r_kh = te.reduce_axis((0, KH), name='r_kh')
+            r_kw = te.reduce_axis((0, KW), name='r_kw')
+            kernel_pack = te.compute((alpha, alpha, CI, CO), lambda eps, nu, 
ci, co:
+                                     te.sum((kernel[r_kh][r_kw][ci][co]) *
+                                            G[eps][r_kh] * G[nu][r_kw],
+                                            axis=[r_kh, r_kw]), 
name='kernel_pack')
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+
+    # Pack input tile
+    input_tile = te.compute((P, CI, alpha, alpha), lambda p, c, eps, nu:
+                            data_pad[idxdiv(p, (nH * nW)),
+                                     idxmod(idxdiv(p, nW), nH) * m + eps,
+                                     idxmod(p, nW) * m + nu,
+                                     c], name='d')
+
+    # Transform data
+    r_a = te.reduce_axis((0, alpha), 'r_a')
+    r_b = te.reduce_axis((0, alpha), 'r_b')
+    data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci:
+                           te.sum(input_tile[p][ci][r_a][r_b] * B[r_a][eps] * 
B[r_b][nu],
+                                  axis=[r_a, r_b]), name='data_pack')
+
+    # Convert data type of input feature maps and weights for tensorcore
+    Transdata = te.compute(
+        data_pack.shape, lambda eps, nu, p, ci: data_pack[eps, nu, p, 
ci].astype(trans_type))
+    TransFilter = te.compute(
+        kernel_pack.shape, lambda eps, nu, ci, co: kernel_pack[eps, nu, ci, 
co].astype(trans_type))
+
+    # Do batch gemm
+    ci = te.reduce_axis((0, CI), name='ci')
+    bgemm = te.compute((alpha, alpha, P, CO), lambda eps, nu, p, co:
+                       te.sum((Transdata[eps][nu][p][ci]).astype(out_dtype) *
+                              (TransFilter[eps][nu][ci][co]).astype(out_dtype),

Review comment:
       Tensor Cores support fp16 inputs with fp16 or fp32 outputs. If the input
data type is fp32, the input should be converted to fp16 before running on the
Tensor Cores. This conversion causes only a limited loss of accuracy.
    
   We added a test file (test_topi_conv2d_nhwc_winograd.py) in this PR. It 
tests the correctness and accuracy of Conv2d for a range of shapes with both
the direct and the Tensor Core Winograd schedules, and it passes with an rtol
of 2e-3.
   
   
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to