This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f92cfe  Improve CUDA conv2d_transpose_nchw (#4762)
4f92cfe is described below

commit 4f92cfe5e2f1524285d4111753ad0e24ebc4e318
Author: Alex Gladkov <gladk...@lab126.com>
AuthorDate: Wed Jan 22 05:41:46 2020 -0800

    Improve CUDA conv2d_transpose_nchw (#4762)
    
    - combine pad and dilate;
    - fix for the issue https://discuss.tvm.ai/t/compile-error-for-cuda-target/4164
    - fix for the issue https://github.com/apache/incubator-tvm/pull/4472
---
 topi/python/topi/cuda/conv2d_transpose_nchw.py     | 136 +++++++++------------
 .../python/test_topi_conv2d_transpose_nchw.py      |  29 +++--
 2 files changed, 74 insertions(+), 91 deletions(-)

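The "combine pad and dilate" bullet above merges what used to be two stages in
the compute (an nn.pad over the input followed by an explicit _dilate step) into
a single data_pad compute that writes every input element straight into a zero
buffer at its dilated, back-padded position. A minimal NumPy sketch of that
equivalence along one spatial axis (illustrative values and helper names only,
not part of the patch):

    import numpy as np

    def pad_then_dilate(x, stride, bpad):
        """Old two-stage reference: dilate by `stride`, then pad with zeros."""
        dilated = np.zeros((len(x) - 1) * stride + 1, dtype=x.dtype)
        dilated[::stride] = x
        return np.pad(dilated, bpad)

    def combined_pad_dilate(x, stride, bpad):
        """Single-stage version, mirroring the idea behind the new data_pad compute."""
        n = len(x)
        out = np.zeros(bpad + (n - 1) * stride + 1 + bpad, dtype=x.dtype)
        for i in range(len(out)):
            # inside the dilated region and on a stride grid point?
            if bpad <= i < bpad + (n - 1) * stride + 1 and (i - bpad) % stride == 0:
                out[i] = x[(i - bpad) // stride]
        return out

    x = np.arange(1, 5, dtype="float32")   # [1, 2, 3, 4]
    stride, bpad = 2, 2                    # e.g. 3x3 kernel, padding 0 -> back pad 2
    assert np.array_equal(pad_then_dilate(x, stride, bpad),
                          combined_pad_dilate(x, stride, bpad))
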
diff --git a/topi/python/topi/cuda/conv2d_transpose_nchw.py b/topi/python/topi/cuda/conv2d_transpose_nchw.py
index 274dfb0..26bc261 100644
--- a/topi/python/topi/cuda/conv2d_transpose_nchw.py
+++ b/topi/python/topi/cuda/conv2d_transpose_nchw.py
@@ -21,11 +21,11 @@ import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from .. import nn, generic
-from ..util import equal_const_int, get_const_tuple, traverse_inline
+from ..util import get_const_tuple, traverse_inline
 
 
 @autotvm.task.register_topi_compute(nn.conv2d_transpose_nchw, ['cuda', 'gpu'], "direct")
-def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
+def conv2d_transpose_nchw_cuda(cfg, data, kernel, stride, padding, out_dtype):
     """Transposed 2D convolution nchw forward operator.
 
     Parameters
@@ -48,67 +48,58 @@ def conv2d_transpose_nchw_cuda(cfg, Input, Filter, strides, padding, out_dtype):
     Output : tvm.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    batch, in_c, in_h, in_w = get_const_tuple(Input.shape)
-    _, out_c, filter_h, filter_w = get_const_tuple(Filter.shape)
-    stride_h, stride_w = strides
-
-    # attach stride info to config, this is used in schedule space definition
-    cfg.stride = strides
-
-    # padding stage
-    fpad_top, fpad_left, fpad_bottom, fpad_right = nn.get_pad_tuple(padding, (filter_h, filter_w))
-    bpad_top = filter_h - 1 - fpad_top
-    bpad_bottom = filter_h - 1 - fpad_bottom
-    bpad_left = filter_w - 1 - fpad_left
-    bpad_right = filter_w - 1 - fpad_right
-
-    # padding stage
-    FirstPad = nn.pad(Input,
-                      [0, 0, (bpad_top + stride_h - 1) // stride_h,
-                       (bpad_left + stride_w - 1) // stride_w],
-                      [0, 0, (bpad_bottom + stride_h - 1) // stride_h,
-                       (bpad_right + stride_w - 1) // stride_w], name='FirstPad')
-
-    idxdiv = tvm.indexdiv
-    idxmod = tvm.indexmod
-    # remove extra padding introduced by dilatation
-    border_h = idxmod(stride_h - idxmod(bpad_top, stride_h), stride_h)
-    border_w = idxmod(stride_w - idxmod(bpad_left, stride_w), stride_w)
-
-    # dilation stage
-    data = FirstPad
-    strides = [1, 1, stride_h, stride_w]
-    n = len(data.shape)
-
-    def _dilate(*indices):
-        not_zero = []
-        index_tuple = []
-        for i in range(n):
-            if not equal_const_int(strides[i], 1):
-                index_tuple.append(idxdiv(indices[i], strides[i]))
-                not_zero.append(idxmod(indices[i], strides[i]).equal(0))
-            else:
-                index_tuple.append(indices[i])
-        if not_zero:
-            not_zero = tvm.all(*not_zero)
-            return tvm.if_then_else(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
-        return data(*index_tuple)
-
-    # convolution stage
-    out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h
-    out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w
-    dc = tvm.reduce_axis((0, in_c), name='dc')
-    dh = tvm.reduce_axis((0, filter_h), name='dh')
-    dw = tvm.reduce_axis((0, filter_w), name='dw')
-
-    Output = tvm.compute(
-        (batch, out_c, out_h, out_w),
+    batch, inp_channels, inp_height, inp_width = get_const_tuple(data.shape)
+    _, out_channels, kernel_height, kernel_width = get_const_tuple(kernel.shape)
+    stride_height, stride_width = stride
+    cfg.stride = stride
+    pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(
+        padding, (kernel_height, kernel_width))
+
+    out_width = (inp_width - 1) * stride_width + \
+                kernel_width - pad_left - pad_right
+    pad_left = kernel_width - 1 - pad_left
+    pad_right = kernel_width - 1 - pad_right
+    dilated_width = stride_width * (inp_width - 1) + 1
+
+    out_height = (inp_height - 1) * stride_height + \
+                 kernel_height - pad_top - pad_bottom
+    pad_top = kernel_height - 1 - pad_top
+    pad_bottom = kernel_height - 1 - pad_bottom
+    dilated_height = stride_height * (inp_height - 1) + 1
+
+    # compute pad
+    data = tvm.compute(
+        (batch, inp_channels,
+         pad_top + dilated_height + pad_bottom,
+         pad_left + dilated_width + pad_right),
+        lambda n, c, y, x: tvm.if_then_else(
+            tvm.all(x >= pad_left,
+                    x < pad_left + dilated_width,
+                    tvm.indexmod(x - pad_left, stride_width).equal(0),
+                    y >= pad_top,
+                    y < pad_top + dilated_height,
+                    tvm.indexmod(y - pad_top, stride_height).equal(0)),
+            data[n, c,
+                 tvm.indexdiv(y - pad_top, stride_height),
+                 tvm.indexdiv(x - pad_left, stride_width)],
+            tvm.const(0., "float32")),
+        name='data_pad')
+
+    # compute transposed conv
+    dc = tvm.reduce_axis((0, inp_channels), name='dc')
+    dh = tvm.reduce_axis((0, kernel_height), name='dh')
+    dw = tvm.reduce_axis((0, kernel_width), name='dw')
+    data_out = tvm.compute(
+        (batch, out_channels, out_height, out_width),
         lambda b, c, h, w: tvm.sum(
-            _dilate(b, dc, h + dh + border_h, w + dw + border_w).astype(out_dtype) *
-            Filter[dc, c, filter_h - 1 - dh, filter_w - 1 - dw].astype(out_dtype),
+            data[b, dc, h + dh, w + dw].astype(out_dtype) *
+            kernel[dc,
+                   c,
+                   kernel_height - 1 - dh,
+                   kernel_width - 1 - dw].astype(out_dtype),
             axis=[dc, dh, dw]), tag="conv2d_transpose_nchw")
 
-    return Output
+    return data_out
 
 @autotvm.task.register_topi_schedule(generic.schedule_conv2d_transpose_nchw,
                                      ['cuda', 'gpu'], 'direct')
@@ -140,7 +131,8 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
         else:
             cfg["tile_n"] = SplitEntity([1, 1, 1, 1])
         # split F (output channel dimension)
-        cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
+        if F > 1:
+            cfg["tile_f"] = SplitEntity([-1, 1, 64, 1])
         # split Y (height dimension)
         y_split_factor = 1
         for candidate in range(5, 17):
@@ -185,26 +177,8 @@ def schedule_conv2d_transpose_nchw_cuda(cfg, outs):
                 cfg.define_knob("unroll_explicit", [0, 1])
 
             if cfg.is_fallback:
-                ko = int(kernel.shape[1])
-                kh = int(kernel.shape[2])
-                kw = int(kernel.shape[3])
-                stride_h, stride_w = cfg.stride
-                # Workaround to make CUDA compilation work. Issue #4470
-                # TODO make _fallback_schedule work for all kernel/strides combinations
-                #  after issue #4470 is resolved
-                do_fallback = True
-                if ko == 1:
-                    do_fallback = False
-                elif (kh, kw) == (1, 1):
-                    do_fallback = True
-                elif (stride_h, stride_w) == (2, 2):
-                    do_fallback = False
-                elif (kh, kw) == (stride_h, stride_w):
-                    do_fallback = False
-
-                if do_fallback:
-                    N, F, Y, X = get_const_tuple(conv.shape)
-                    _fallback_schedule(N, F, Y, X)
+                N, F, Y, X = get_const_tuple(conv.shape)
+                _fallback_schedule(N, F, Y, X)
 
             ##### space definition end #####
 
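As a hedged sanity check of the shape arithmetic used by the new data_pad/conv
compute above (plain Python, illustrative values and a hypothetical helper name,
not part of the patch): after back-padding and dilation, a dense stride-1
convolution over data_pad yields exactly out_height x out_width.

    def transposed_conv_shapes(in_size, kernel, stride, pad):
        out = (in_size - 1) * stride + kernel - 2 * pad  # symmetric padding assumed
        bpad = kernel - 1 - pad                          # back padding on each side
        dilated = stride * (in_size - 1) + 1
        padded = bpad + dilated + bpad
        assert padded - kernel + 1 == out                # dense conv over data_pad
        return out, padded

    # e.g. 32x32 input, 5x5 kernel, stride 2, padding 1 (one of the test cases below)
    print(transposed_conv_shapes(32, 5, 2, 1))           # -> (65, 69)
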
diff --git a/topi/tests/python/test_topi_conv2d_transpose_nchw.py b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
index 0960760..fb836d4 100644
--- a/topi/tests/python/test_topi_conv2d_transpose_nchw.py
+++ b/topi/tests/python/test_topi_conv2d_transpose_nchw.py
@@ -25,10 +25,13 @@ from topi.util import get_const_tuple
 from common import get_all_backend
 
 def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
-    in_height = in_width = in_size
+    in_height, in_width = in_size
+    kernel_height, kernel_width = kernel
+    stride_height, stride_width = stride
+    pad_top, pad_left, pad_bottom, pad_right = padding
 
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
-    W = tvm.placeholder((in_channel, num_filter, kernel, kernel), name='W')
+    W = tvm.placeholder((in_channel, num_filter, kernel_height, kernel_width), name='W')
 
     a_shape = get_const_tuple(A.shape)
     w_shape = get_const_tuple(W.shape)
@@ -51,7 +54,10 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
-            B = topi.nn.conv2d_transpose_nchw(A, W, [stride, stride], [padding, padding], A.dtype)
+            B = topi.nn.conv2d_transpose_nchw(A, W,
+                                              [stride_height, stride_width],
+                                              [pad_top, pad_left, pad_bottom, pad_right],
+                                              A.dtype)
             C = topi.nn.relu(B)
             s1 = topi.generic.schedule_conv2d_transpose_nchw([B])
             s2 = topi.generic.schedule_conv2d_transpose_nchw([C])
@@ -66,18 +72,21 @@ def verify_conv2d_transpose_nchw(batch, in_channel, in_size, num_filter, kernel,
         func2(a, w, c)
         tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
         tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
-
     for device in get_all_backend():
         check_device(device)
 
 
 def test_conv2d_transpose_nchw():
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 1, 0)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 3, 2, 1)
-    verify_conv2d_transpose_nchw(1, 3, 224, 32, 2, 2, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 1, 0)
-    verify_conv2d_transpose_nchw(1, 32, 32, 128, 5, 2, 1)
-
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  1, (1, 1), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (3, 3), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (3, 3), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(1, 3, (224, 224),  32, (2, 2), (2, 2), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (1, 1), (0, 0, 0, 0))
+    verify_conv2d_transpose_nchw(1, 32, (32, 32), 128, (5, 5), (2, 2), (1, 1, 1, 1))
+    verify_conv2d_transpose_nchw(16, 32, (8192, 1), 8, (31, 1), (2, 1), (14, 0, 15, 0))
+    verify_conv2d_transpose_nchw(16, 512, (8, 1), 128, (31, 1), (2, 1), (14, 0, 15, 0))
 
 if __name__ == "__main__":
     test_conv2d_transpose_nchw()

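For reference, a hedged usage sketch (not part of the patch) of the tuple-style
arguments the updated test now exercises, e.g. the asymmetric 8192x1 case with
per-axis stride and an explicit (top, left, bottom, right) padding:

    import tvm
    import topi

    data = tvm.placeholder((16, 32, 8192, 1), name="data")
    weight = tvm.placeholder((32, 8, 31, 1), name="weight")

    with tvm.target.create("cuda"):
        out = topi.nn.conv2d_transpose_nchw(data, weight,
                                            [2, 1],          # (stride_h, stride_w)
                                            [14, 0, 15, 0],  # (top, left, bottom, right)
                                            data.dtype)
        s = topi.generic.schedule_conv2d_transpose_nchw([out])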