This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 3df8d56  [Topi] Tensorcore support for Conv3D (#5284)
3df8d56 is described below

commit 3df8d560f2b6d34ba43a069cd5809560d2c96983
Author: Josh Fromm <[email protected]>
AuthorDate: Mon Apr 13 10:49:17 2020 -0700

    [Topi] Tensorcore support for Conv3D (#5284)
    
    * one weird trick.
    
    * Added schedule knob for different workloads.
    
    * Initial conv3d tensorcore working.
    
    * Added conv3d tensorcore strategy.
    
    * Added layout conversion to tensorcore friendly format for conv2d and conv3d.
    
    * Add target name check.
    
    * Fixed bad names and depthwise check.
    
    * Removed duplicated attribute assignment.
---
 python/tvm/relay/op/nn/_nn.py                      |  52 ++++++++-
 python/tvm/relay/op/strategy/cuda.py               |  47 +++++---
 topi/python/topi/cuda/__init__.py                  |   1 +
 topi/python/topi/cuda/conv2d_nhwc_tensorcore.py    |   2 +-
 ...wc_tensorcore.py => conv3d_ndhwc_tensorcore.py} | 127 +++++++++++----------
 topi/python/topi/cuda/conv3d_winograd.py           |  16 ++-
 .../python/test_topi_conv3d_ndhwc_tensorcore.py    | 127 +++++++++++++++++++++
 7 files changed, 287 insertions(+), 85 deletions(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 51e7128..5f6aa89 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -27,6 +27,7 @@ from .. import op as reg
 from .. import strategy
 from ..op import OpPattern
 from .._tensor import elemwise_shape_func
+from ..strategy.generic import is_depthwise_conv2d
 
 # relu
 reg.register_broadcast_schedule("nn.relu")
@@ -139,13 +140,21 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
     # pylint: disable=import-outside-toplevel
     from tvm import relay
     data, weight = inputs
-    assert desired_layout == 'NCHW', \
-            "Currently only transformation to NCHW layout is supported."
+    new_attrs = dict(attrs)
+    new_attrs['data_layout'] = desired_layout
     if desired_layout == 'NCHW':
-        new_attrs = dict(attrs)
-        new_attrs['data_layout'] = desired_layout
         new_attrs['kernel_layout'] = 'OIHW'
         return relay.nn.conv2d(data, weight, **new_attrs)
+    elif desired_layout == 'NHWC':
+        # Check for depthwise convolution.
+        if is_depthwise_conv2d(data.shape, attrs['data_layout'], weight.shape,
+                               attrs['kernel_layout'], attrs['groups']):
+            new_attrs['kernel_layout'] = 'HWOI'
+        else:
+            new_attrs['kernel_layout'] = 'HWIO'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+    else:
+        assert "Layout %s is not yet supported." % (desired_layout)
     return None
 
 
@@ -183,6 +192,41 @@ def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
     """Alternate the layout of conv3d"""
     return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
 
+@reg.register_convert_op_layout("nn.conv3d")
+def convert_conv3d(attrs, inputs, tinfos, desired_layout):
+    """Convert Layout pass registration for conv3d op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layout : str
+        The desired layout
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    data, weight = inputs
+    new_attrs = dict(attrs)
+    new_attrs['data_layout'] = desired_layout
+    if desired_layout == 'NCDHW':
+        new_attrs['kernel_layout'] = 'OIDHW'
+        return relay.nn.conv3d(data, weight, **new_attrs)
+    elif desired_layout == "NDHWC":
+        new_attrs['kernel_layout'] = 'DHWIO'
+        return relay.nn.conv3d(data, weight, **new_attrs)
+    else:
+        assert "Layout %s is not yet supported" % desired_layout
+    return None
+
 # conv3d_winograd related operators
 reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
                       strategy.conv3d_winograd_without_weight_transfrom_strategy)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 845be66..4e5088f 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -138,15 +138,16 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
                 name="conv2d_nhwc.cuda")
             N, _, _, _ = get_const_tuple(data.shape)
             _, _, CI, CO = get_const_tuple(kernel.shape)
-            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
-                if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
-                        (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
-                        (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
-                    strategy.add_implementation(
-                        wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
-                        wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
-                        name="conv2d_nhwc_tensorcore.cuda",
-                        plevel=20)
+            if target.target_name == "cuda":
+                if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+                    if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                            (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                            (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(topi.cuda.conv2d_nhwc_tensorcore),
+                            wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_tensorcore),
+                            name="conv2d_nhwc_tensorcore.cuda",
+                            plevel=20)
         elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
@@ -170,7 +171,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
             strategy.add_implementation(
                 wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                 wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="dpethwise_nchw.cuda")
+                name="depthwise_conv2d_nchw.cuda")
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
             strategy.add_implementation(
@@ -249,7 +250,7 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
 def conv3d_strategy_cuda(attrs, inputs, out_type, target):
     """conv3d cuda strategy"""
     strategy = _op.OpStrategy()
-    _, kernel = inputs
+    data, kernel = inputs
     layout = attrs.data_layout
     _, stride_h, stride_w = attrs.get_int_tuple("strides")
     _, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
@@ -268,11 +269,25 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
                 name="conv3d_ncdhw_winograd.cuda",
                 plevel=5)
-    else: # layout == "NDHWC":
-        strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
-                                    wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
-                                    name="conv3d_ndhwc.cuda",
-                                    plevel=10)
+    else:  # layout == "NDHWC":
+        strategy.add_implementation(
+            wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
+            wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
+            name="conv3d_ndhwc.cuda",
+            plevel=10)
+        N, _, _, _, _ = get_const_tuple(data.shape)
+        _, _, _, CI, CO = get_const_tuple(kernel.shape)
+        if target.target_name == "cuda":
+            if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
+                if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
+                (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
+                (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
+                    strategy.add_implementation(
+                        wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
+                        wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),
+                        name="conv3d_ndhwc_tensorcore.cuda",
+                        plevel=20)
+
     if target.target_name == "cuda" and "cudnn" in target.libs:
         strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
                                     wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index c20e257..2b7a845 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -46,4 +46,5 @@ from .nms import get_valid_counts, non_max_suppression
 from .rcnn import *
 from .sort import *
 from .conv2d_nhwc_tensorcore import *
+from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
index 8f8f93d..790db0f 100644
--- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
@@ -70,7 +70,7 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
     # convert data type of input feature maps and weights
     TransPaddedInput = te.compute(
         PaddedInput.shape,
-        lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16'))
+        lambda n, h, w, c: PaddedInput[n, h, w, c].astype('float16'))
     TransFilter = te.compute(
         Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
     Output = te.compute(
diff --git a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
similarity index 71%
copy from topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
copy to topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
index 8f8f93d..e3c7513 100644
--- a/topi/python/topi/cuda/conv2d_nhwc_tensorcore.py
+++ b/topi/python/topi/cuda/conv3d_ndhwc_tensorcore.py
@@ -23,30 +23,30 @@ from tvm import te
 from tvm import autotvm
 from ..util import get_const_tuple, traverse_inline, simplify
 from ..nn.pad import pad
-from ..nn.util import get_pad_tuple
+from ..nn.util import get_pad_tuple3d
 from .tensor_intrin import intrin_wmma_load_matrix_A
 from .tensor_intrin import intrin_wmma_load_matrix_W
 from .tensor_intrin import intrin_wmma_store_matrix
 from .tensor_intrin import intrin_wmma_gemm
 
 
-def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
-    """Compute declaration for tensorcore"""
-    assert isinstance(stride, int) or len(stride) == 2
-    assert isinstance(dilation, int) or len(dilation) == 2
+def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
+    """Compute declaration for conv3d tensorcore function"""
+    assert isinstance(stride, int) or len(stride) == 3
+    assert isinstance(dilation, int) or len(dilation) == 3
 
     if isinstance(stride, int):
-        stride_h = stride_w = stride
+        stride_d = stride_h = stride_w = stride
     else:
-        stride_h, stride_w = stride
+        stride_d, stride_h, stride_w = stride
 
     if isinstance(dilation, int):
-        dilation_h = dilation_w = dilation
+        dilation_d = dilation_h = dilation_w = dilation
     else:
-        dilation_h, dilation_w = dilation
+        dilation_d, dilation_h, dilation_w = dilation
 
-    batch, in_height, in_width, in_channel = get_const_tuple(Input.shape)
-    kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
+    batch, in_depth, in_height, in_width, in_channel = get_const_tuple(Input.shape)
+    kernel_d, kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
     assert (batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0) or \
                (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0) or \
                (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0), \
@@ -54,43 +54,50 @@ def nhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtyp
                "must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
 
     # compute the output shape
+    dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
-    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
-        padding, (dilated_kernel_h, dilated_kernel_w))
+    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
     out_channel = num_filter
+    out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
     out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
     out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
-    pad_before = [0, pad_top, pad_left, 0]
-    pad_after = [0, pad_down, pad_right, 0]
+    pad_before = [0, pad_front, pad_top, pad_left, 0]
+    pad_after = [0, pad_back, pad_down, pad_right, 0]
     PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
     rc = te.reduce_axis((0, in_channel), name='rc')
+    rz = te.reduce_axis((0, kernel_d), name='rz')
     ry = te.reduce_axis((0, kernel_h), name='ry')
     rx = te.reduce_axis((0, kernel_w), name='rx')
     # convert data type of input feature maps and weights
     TransPaddedInput = te.compute(
         PaddedInput.shape,
-        lambda h, w, i, o: PaddedInput[h, w, i, o].astype('float16'))
+        lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype('float16'))
     TransFilter = te.compute(
-        Filter.shape, lambda h, w, i, o: Filter[h, w, i, o].astype('float16'))
+        Filter.shape, lambda d, h, w, i, o: Filter[d, h, w, i, o].astype('float16'))
     Output = te.compute(
-        (batch, out_height, out_width, out_channel),
-        lambda nn, yy, xx, ff: te.sum(
-            TransPaddedInput[nn, yy * stride_h + ry * dilation_h,
-                             xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
-            TransFilter[ry, rx, rc, ff].astype(out_dtype), axis=[ry, rx, rc]),
-        name="Conv2dOutput", tag="conv2d_nhwc_tensorcore")
+        (batch, out_depth, out_height, out_width, out_channel),
+        lambda nn, zz, yy, xx, ff: te.sum(
+            TransPaddedInput[nn,
+                             zz * stride_d + rz * dilation_d,
+                             yy * stride_h + ry * dilation_h,
+                             xx * stride_w + rx * dilation_w,
+                             rc].astype(out_dtype) *
+            TransFilter[rz, ry, rx, rc, ff].astype(out_dtype),
+            axis=[rz, ry, rx, rc]),
+        name="Conv3dOutput", tag="conv3d_ndhwc_tensorcore")
     return Output
 
 
-def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
+def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
     """Schedule tensorcore template"""
-    kh, kw, ic = s[Conv].op.reduce_axis
+    kd, kh, kw, ic = s[Conv].op.reduce_axis
     out_dtype = Conv.dtype
     trans_paddata, kernel = s[Conv].op.input_tensors
     in_dtype = trans_paddata.dtype
-    batch, _, _, _ = get_const_tuple(Conv.shape)
-    _, _, _, out_channels = get_const_tuple(kernel.shape)
+    batch, _, _, _, _ = get_const_tuple(Conv.shape)
+    _, _, _, _, out_channels = get_const_tuple(kernel.shape)
     paddata = s[trans_paddata].op.input_tensors
 
     # inline the pad and dtype transform
@@ -134,7 +141,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     target = tvm.target.Target.current()
     if cfg.is_fallback:
         ref_log = autotvm.tophub.load_reference_log(
-            target.target_name, target.model, 'conv2d_nhwc_tensorcore.cuda')
+            target.target_name, target.model, 'conv3d_ndhwc_tensorcore.cuda')
         cfg.fallback_with_reference_log(ref_log)
 
     block_row_warps = cfg["block_row_warps"].val
@@ -172,16 +179,16 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     block_factor_n = wmma_m * warp_row_tiles * block_row_warps
     block_factor_o = wmma_n * warp_col_tiles * block_col_warps
     CS_align = block_factor_o + offset
-    AS_strides = get_strides([1, 1, AS_align, 1])
-    AL_strides = get_strides([1, 1, wmma_k, 1])
+    AS_strides = get_strides([1, 1, 1, AS_align, 1])
+    AL_strides = get_strides([1, 1, 1, wmma_k, 1])
     WS_strides = get_strides([WS_align, 1])
     WL_strides = get_strides([wmma_n * warp_col_tiles, 1])
-    CL_strides = get_strides([1, 1, wmma_n * warp_col_tiles, 1])
-    CS_strides = get_strides([1, 1, CS_align, 1])
+    CL_strides = get_strides([1, 1, 1, wmma_n * warp_col_tiles, 1])
+    CS_strides = get_strides([1, 1, 1, CS_align, 1])
 
     # Schedule for output
-    nc, hc, wc, oc = output.op.axis
-    block_k = s[output].fuse(hc, wc)
+    nc, dc, hc, wc, oc = output.op.axis
+    block_k = s[output].fuse(dc, hc, wc)
     s[output].bind(block_k, block_z)
     block_i, nc = s[output].split(nc, factor=block_factor_n)
     block_j, oc = s[output].split(oc, factor=block_factor_o)
@@ -200,8 +207,8 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
 
     # Schedule wmma store
     s[OL].compute_at(s[output], block_j)
-    nc, hc, wc, oc = OL.op.axis
-    s[OL].reorder(hc, wc, nc, oc)
+    nc, dc, hc, wc, oc = OL.op.axis
+    s[OL].reorder(dc, hc, wc, nc, oc)
     s[OL].storage_align(wc, CS_align - 1, CS_align)
     oc, ooc = s[OL].split(oc, factor=wmma_n)
     oc, oci = s[OL].split(oc, factor=warp_col_tiles)
@@ -215,23 +222,23 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
 
     # Schedule wmma computation
     s[ConvF].compute_at(s[OL], oc)
-    n, h, w, o = ConvF.op.axis
+    n, d, h, w, o = ConvF.op.axis
     n, nnf = s[ConvF].split(n, factor=wmma_m)
     o, oof = s[ConvF].split(o, factor=wmma_n)
     ic, ii = s[ConvF].split(ic, factor=wmma_k)
     ko, ki = s[ConvF].split(ic, factor=chunk)
-    s[ConvF].reorder(kh, kw, ko, ki, n, o, nnf, oof, ii)
+    s[ConvF].reorder(kd, kh, kw, ko, ki, n, o, nnf, oof, ii)
 
     s[AF].compute_at(s[ConvF], ki)
     s[WF].compute_at(s[ConvF], ki)
 
     # Schedule wmma load
-    n, h, w, i = AF.op.axis
+    n, d, h, w, i = AF.op.axis
     n, nn = s[AF].split(n, factor=wmma_m)
     i, ii = s[AF].split(i, factor=wmma_k)
     s[AF].reorder(n, i, nn, ii)
 
-    kh, kw, i, o = WF.op.axis
+    kd, kh, kw, i, o = WF.op.axis
     i, ii = s[WF].split(i, factor=wmma_k)
     o, oo = s[WF].split(o, factor=wmma_n)
     s[WF].reorder(o, i, oo)
@@ -241,8 +248,8 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     s[AS].compute_at(s[ConvF], ko)
 
     # Schedule for data's share memory
-    n, h, w, i = AS.op.axis
-    s[AS].reorder(h, w, n, i)
+    n, d, h, w, i = AS.op.axis
+    s[AS].reorder(d, h, w, n, i)
     s[AS].storage_align(w, AS_align - 1, AS_align)
     t = s[AS].fuse(n, i)
     t, ti = s[AS].split(t, factor=vector_width)
@@ -255,7 +262,7 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     s[AS].vectorize(ti)
 
     # Schedule for kernel's share memory
-    kh, kw, ic, o = WS.op.axis
+    kd, kh, kw, ic, o = WS.op.axis
     t = s[WS].fuse(ic, o)
     s[WS].storage_align(ic, WS_align - 1, WS_align)
     t, ti = s[WS].split(t, factor=vector_width)
@@ -270,18 +277,18 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     shape = (wmma_m, wmma_n, wmma_k)
 
     # tensorize the wmma process
-    AS_shape = (wmma_m, 1, 1, wmma_k)
-    AL_shape = (wmma_m, 1, 1, wmma_k)
+    AS_shape = (wmma_m, 1, 1, 1, wmma_k)
+    AL_shape = (wmma_m, 1, 1, 1, wmma_k)
     WS_shape = (wmma_k, wmma_n)
     WL_shape = (wmma_k, wmma_n)
-    CL_shape = (wmma_m, 1, 1, wmma_n)
-    CS_shape = (wmma_m, 1, 1, wmma_n)
+    CL_shape = (wmma_m, 1, 1, 1, wmma_n)
+    CS_shape = (wmma_m, 1, 1, 1, wmma_n)
 
     AL_gemm = te.placeholder(AL_shape, name='A', dtype=in_dtype)
     WL_gemm = te.placeholder(WL_shape, name='B', dtype=in_dtype)
     k_gemm = te.reduce_axis((0, wmma_k), name="k")
-    CL_compute = te.compute(CL_shape, lambda ii, t0, t1, jj:
-                            te.sum(AL_gemm[ii, t0, t1, k_gemm].astype(out_dtype) * \
+    CL_compute = te.compute(CL_shape, lambda ii, t0, t1, t2, jj:
+                            te.sum(AL_gemm[ii, t0, t1, t2, k_gemm].astype(out_dtype) * \
                                    WL_gemm[k_gemm, jj].astype(out_dtype), axis=k_gemm),
                             name='C')
 
@@ -294,25 +301,25 @@ def schedule_nhwc_tensorcore_cuda(cfg, s, Conv):
     s[ConvF].tensorize(nnf, intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides,
                                              WL_strides, CL_strides, shape))
 
-    N, OH, OW, CO = get_const_tuple(output.shape)
-    KH, KW, CI, _ = get_const_tuple(kernel.shape)
-    cfg.add_flop(2 * N * OH * OW * CO * CI * KH * KW)
+    N, OD, OH, OW, CO = get_const_tuple(output.shape)
+    KD, KH, KW, CI, _ = get_const_tuple(kernel.shape)
+    cfg.add_flop(2 * N * OD * OH * OW * CO * CI * KD * KH * KW)
 
 
-@autotvm.register_topi_compute("conv2d_nhwc_tensorcore.cuda")
-def conv2d_nhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
-    """Compute conv2d with tensorcore for NCHW layout"""
-    return nhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype)
+@autotvm.register_topi_compute("conv3d_ndhwc_tensorcore.cuda")
+def conv3d_ndhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv3d with tensorcore for NDHWC layout"""
+    return ndhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype)
 
 
-@autotvm.register_topi_schedule("conv2d_nhwc_tensorcore.cuda")
-def schedule_conv2d_nhwc_tensorcore(cfg, outs):
+@autotvm.register_topi_schedule("conv3d_ndhwc_tensorcore.cuda")
+def schedule_conv3d_ndhwc_tensorcore(cfg, outs):
     """TOPI schedule callback"""
     s = te.create_schedule([x.op for x in outs])
 
     def _callback(op):
-        if 'conv2d_nhwc_tensorcore' in op.tag:
-            schedule_nhwc_tensorcore_cuda(cfg, s, op.output(0))
+        if 'conv3d_ndhwc_tensorcore' in op.tag:
+            schedule_ndhwc_tensorcore_cuda(cfg, s, op.output(0))
 
     traverse_inline(s, outs[0].op, _callback)
     return s
diff --git a/topi/python/topi/cuda/conv3d_winograd.py b/topi/python/topi/cuda/conv3d_winograd.py
index c9e8446..5876243 100644
--- a/topi/python/topi/cuda/conv3d_winograd.py
+++ b/topi/python/topi/cuda/conv3d_winograd.py
@@ -493,13 +493,17 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     BB = s.cache_read(B0, 'shared', [OL])
 
     b = s[bgemm].fuse(b1, b2)
-    y = s[bgemm].fuse(z, y)
-
+    # Allow two different tiling strategies as both seem
+    # to work best in different cases.
+    cfg.define_knob("unroll_axis", [0, 1])
     # tile and bind spatial axes
     bgemm_scope, b = s[bgemm].split(b, nparts=1)
     bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
-    by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
-    bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, C, z)
+    if cfg['unroll_axis'].val:
+        bx, vx, tx, xi = cfg["tile_x"].apply(s, C, y)
+    else:
+        bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
     s[C].bind(bz, te.thread_axis("blockIdx.z"))
     s[C].bind(by, te.thread_axis("blockIdx.y"))
     s[C].bind(bx, te.thread_axis("blockIdx.x"))
@@ -510,6 +514,10 @@ def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
     s[C].bind(ty, te.thread_axis("threadIdx.y"))
     s[C].bind(tx, te.thread_axis("threadIdx.x"))
     s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
+    if cfg['unroll_axis'].val:
+        s[C].unroll(x)
+    else:
+        s[C].unroll(y)
 
     # tile reduction axes
     s[OL].compute_at(s[C], tx)
diff --git a/topi/tests/python/test_topi_conv3d_ndhwc_tensorcore.py b/topi/tests/python/test_topi_conv3d_ndhwc_tensorcore.py
new file mode 100644
index 0000000..f98550f
--- /dev/null
+++ b/topi/tests/python/test_topi_conv3d_ndhwc_tensorcore.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, too-many-locals, too-many-arguments
+"""Example code to do convolution."""
+
+import numpy as np
+import tvm
+import topi
+import topi.testing
+from tvm import te
+from tvm.contrib.pickle_memoize import memoize
+from tvm.contrib import nvcc
+from topi.nn.util import get_pad_tuple3d
+from topi.util import get_const_tuple
+
+
+_conv3d_ndhwc_tensorcore_implement = {
+    "cuda": (topi.cuda.conv3d_ndhwc_tensorcore, topi.cuda.schedule_conv3d_ndhwc_tensorcore)
+}
+
+
+def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride,
+                        padding, dilation=1, add_bias=False, add_relu=False, devices='cuda'):
+    """Test the conv3d with tensorcore for ndhwc layout"""
+    pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
+        padding, (kernel, kernel, kernel))
+    padding_sum = pad_front + pad_top + pad_left + pad_back + pad_bottom + pad_right
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (
+        batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+
+    in_depth = in_height = in_width = in_size
+
+    A = te.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
+    W = te.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
+    bias = te.placeholder((1, 1, 1, 1, num_filter), name='bias')
+
+    a_shape = get_const_tuple(A.shape)
+    w_shape = get_const_tuple(W.shape)
+    bias_shape = get_const_tuple(bias.shape)
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_conv3d_ndhwc.verify_conv3d_ndhwc")
+    def get_ref_data():
+        a_np = np.random.uniform(size=a_shape).astype(dtype)
+        w_np = np.random.uniform(size=w_shape).astype(dtype)
+        b_np = np.random.uniform(size=bias_shape).astype(dtype)
+        dw_np = topi.testing.dilate_python(w_np, (1, 1, 1, dilation, dilation))
+        c_np = topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding)
+        if add_bias:
+            b_np = np.random.uniform(size=bias_shape).astype(dtype)
+            c_np += b_np
+        if add_relu:
+            c_np = np.maximum(c_np, 0)
+        return a_np, w_np, b_np, c_np
+
+    a_np, w_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        if not nvcc.have_tensorcore(ctx.compute_version):
+            print("skip because gpu does not support Tensor Cores")
+            return
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ndhwc_tensorcore_implement)
+            C = fcompute(A, W, stride, padding, dilation, 'float32')
+            if add_bias:
+                C = topi.add(C, bias)
+            if add_relu:
+                C = topi.nn.relu(C)
+            s = fschedule([C])
+
+        a = tvm.nd.array(a_np, ctx)
+        w = tvm.nd.array(w_np, ctx)
+        b = tvm.nd.array(b_np, ctx)
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
+        if add_bias:
+            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
+                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func(a, w, b, c)
+        else:
+            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (
+                batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
+            func(a, w, c)
+            func(a, w, c)
+
+        rtol = 1e-3
+        tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
+
+    check_device(devices)
+
+
+def test_conv3d_ndhwc_tensorcore():
+    """Test the conv3d with tensorcore for ndhwc layout"""
+    verify_conv3d_ndhwc(16, 16, 14, 16, 3, 1, 1)
+    verify_conv3d_ndhwc(16, 64, 7, 64, 7, 1, 3)
+    verify_conv3d_ndhwc(16, 32, 7, 32, 7, 1, 3)
+
+    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_bias=True)
+    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_relu=True)
+    verify_conv3d_ndhwc(32, 16, 14, 16, 3, 1, 1, add_relu=True, add_bias=True)
+
+    verify_conv3d_ndhwc(16, 16, 17, 16, 7, 1, (3, 3, 3, 2, 2, 2))
+    verify_conv3d_ndhwc(16, 16, 17, 16, 7, 1, "SAME")
+    verify_conv3d_ndhwc(8, 16, 35, 32, 5, 1, "VALID")
+    verify_conv3d_ndhwc(16, 32, 16, 32, 3, 1, (1, 1, 1, 1, 1, 1))
+    verify_conv3d_ndhwc(16, 16, 12, 16, 3, 1, (1, 1, 1, 1, 1, 1))
+
+
+if __name__ == "__main__":
+    test_conv3d_ndhwc_tensorcore()

Reply via email to