This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d35b858  [CUDNN] Support gradient kernels (#9986)
d35b858 is described below

commit d35b858ceb5c7157c3faa78c88eb8f8971ba5e96
Author: Masahiro Masuda <[email protected]>
AuthorDate: Sun Jan 23 06:58:31 2022 +0900

    [CUDNN] Support gradient kernels (#9986)
    
    * Dgrad nchw, nhwc, fp16 working
    
    commit 426e5dca446a27da49270f45171b58f1bfa21fa9
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 11:48:53 2022 +0900
    
        black
    
    commit 211a58b80f4d0f0b5b0230720e41f35e50cb1eaf
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 11:43:52 2022 +0900
    
        fp16 also works
    
    commit c2a34d473b063873628bff00e51a44cd8e4d0e4f
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 11:36:36 2022 +0900
    
        nhwc test also worked
    
    commit c0609ab147fef30c230a94d16b6c1ba35f7dd9c0
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 11:21:23 2022 +0900
    
        nchw test worked
    
    commit 2bf68c72763708151e9f49f09916a210b2547be8
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 10:41:35 2022 +0900
    
        add test stub
    
    commit c86b1288d5e371f12cba4e1b1866966cb9264401
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 10:32:09 2022 +0900
    
        add python definition stub
    
    commit 3166952f9673376801bf4b5b39eeb6f89452f30a
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 06:57:18 2022 +0900
    
        bwd filter compiled
    
    commit e311ba3d05c5f9424ecb952cb5a520ce81a0828a
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 06:27:55 2022 +0900
    
        dgrad compiled
    
    commit 47f35beb5eeeb7cbf9f6ec7cf8f5c80c65e8da46
    Author: Masahiro Masuda <[email protected]>
    Date:   Tue Jan 18 06:16:43 2022 +0900
    
        add dgrad stub
    
    commit ebed032d15b1c3895f541c46ce5d80b6dd769034
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Jan 17 17:01:56 2022 +0900
    
        cpplint
    
    commit 834f54a8c13512130e7d91ca0f54268dc06c5481
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Jan 17 16:55:58 2022 +0900
    
        remove cudnn get output
    
    commit dcbd9c95fdb8ffef9db9c2350430b270461a31c3
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Jan 17 16:28:07 2022 +0900
    
        more refactor
    
    commit 146464e8496fff972bdb1687c4e9d432fe3278d5
    Author: Masahiro Masuda <[email protected]>
    Date:   Mon Jan 17 15:57:35 2022 +0900
    
        Introduce SetConvDescriptors to refactor cudnn/conv_forward.cc
    
    * add python function for cudnn wgrad
    
    * adding wgrad test
    
    * black
    
    * wgrad nchw and nhwc worked
    
    * remove bwd algo name stuff
    
    * compute output shape properly
    
    * swap arg order in wgrad
    
    * add kernel size arg in test
    
    * black
    
    * cleanup
    
    * more fixes
    
    * fix dgrad test
    
    * support running relay conv2d_backward_weight directly with cudnn
    
    * black
    
    * refactor reference function to support nhwc
    
    * remove unused function
    
    * lint
    
    * enable offloading conv2d_transpose to cudnn dgrad
    
    * relax tolerance
    
    * name fix, remove print
---
 python/tvm/contrib/cudnn.py                        | 460 +++++++++++++++++++--
 python/tvm/relay/op/nn/_nn.py                      |   4 +
 python/tvm/relay/op/strategy/cuda.py               |  28 ++
 python/tvm/relay/op/strategy/generic.py            |  38 ++
 python/tvm/topi/cuda/conv2d.py                     |  19 +
 python/tvm/topi/cuda/conv2d_transpose_nchw.py      |   8 +
 python/tvm/topi/nn/conv2d_transpose.py             |   1 -
 python/tvm/topi/testing/__init__.py                |   2 +-
 .../topi/testing/conv2d_backcward_weight_python.py |  44 +-
 python/tvm/topi/testing/conv2d_transpose_python.py |   4 +-
 src/relay/op/nn/convolution.cc                     |   1 -
 src/runtime/contrib/cudnn/conv_backward.cc         | 265 ++++++++++++
 src/runtime/contrib/cudnn/conv_forward.cc          |   4 +-
 src/runtime/contrib/cudnn/cudnn_utils.h            |   4 +-
 tests/python/contrib/test_cudnn.py                 | 138 +++++++
 tests/python/relay/test_op_grad_level2.py          |  33 +-
 tests/python/relay/test_op_level2.py               |  17 +-
 17 files changed, 996 insertions(+), 74 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 9b92c7c..c897de7 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -36,33 +36,6 @@ _FWD_ALGOS = [
     "CUDNN_CONVOLUTION_FWD_ALGO_COUNT",
 ]
 
-_BWD_FILTER_ALGOS = [
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0",
-    # non-deterministic
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
-    # non-deterministic, algo0 with workspaceS
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD",
-    # not implemented
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
-    "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT",
-]
-
-_BWD_DATA_ALGOS = [
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",
-    # non-deterministic
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED",
-    "CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT",
-]
-
-_ALGO_TYPE = ["fwd", "bwd_filter", "bwd_data"]
-
 
 def exists():
     """
@@ -285,7 +258,74 @@ def conv_output_shape(
     return output
 
 
-def conv_find_algo(
+def conv_dgrad_shape(
+    tensor_format, pad, stride, dilation, dy_shape, w_shape, 
output_padding=(0, 0)
+):
+    """Get output shape of conv2d gradient with respect to data
+
+    Paramters
+    ---------
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+    pad: int or list
+        padding
+    stride: int or list
+        stride
+    dilation: int or list
+        dilation
+    dy_shape: list
+        output gradient shape
+    w_shape: list
+        weight shape
+    data_dtype: str
+        data type
+    conv_dtype: str
+        convolution type
+    groups: int
+        number of groups
+
+    Returns
+    -------
+    oshape: list
+        output shape
+    """
+
+    assert len(dy_shape) == len(w_shape)
+    assert len(dy_shape) == 4
+
+    if tensor_format == 0:
+        N = dy_shape[0]
+        C = w_shape[1]
+        dy_shape = dy_shape[2:]
+        w_shape = w_shape[2:]
+    elif tensor_format == 1:
+        N = dy_shape[0]
+        C = w_shape[-1]
+        dy_shape = dy_shape[1:-1]
+        w_shape = w_shape[1:-1]
+    else:
+        raise ValueError("Unsupported CuDNN tensor format: 
'{}'".format(tensor_format))
+
+    input_dims = []
+    for dy_shape_i, w_shape_i, pad_i, stride_i, dilation_i, out_pad in zip(
+        dy_shape, w_shape, pad, stride, dilation, output_padding
+    ):
+        input_dim = (
+            (dy_shape_i - 1) * stride_i - 2 * pad_i + (((w_shape_i - 1) * 
dilation_i) + 1) + out_pad
+        )
+        input_dims.append(input_dim)
+
+    if tensor_format == 0:
+        output = [N, C, *input_dims]
+    else:
+        output = [N, *input_dims, C]
+
+    return output
+
+
+def _conv_find_algo(
+    func_name,
     tensor_format,
     pad,
     stride,
@@ -297,7 +337,46 @@ def conv_find_algo(
     conv_dtype,
     groups=1,
 ):
-    """Choose the best algo for the given input.
+    """
+    Common function to choose the best cudnn convolution algorithm for the 
given input
+    and the convolution type.
+    """
+    dims = len(x_shape)
+    assert dims in (4, 5)
+
+    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
+        dims - 2, pad, stride, dilation, x_shape, w_shape
+    )
+    yshape = np.array(y_shape, dtype=np.int32)
+    func = tvm._ffi.get_global_func(func_name)
+    return func(
+        tensor_format,
+        dims - 2,
+        _get_np_int32_array_handle(pad),
+        _get_np_int32_array_handle(stride),
+        _get_np_int32_array_handle(dilation),
+        _get_np_int32_array_handle(xshape),
+        _get_np_int32_array_handle(wshape),
+        _get_np_int32_array_handle(yshape),
+        data_dtype,
+        conv_dtype,
+        groups,
+    )
+
+
+def conv_forward_find_algo(
+    tensor_format,
+    pad,
+    stride,
+    dilation,
+    x_shape,
+    w_shape,
+    y_shape,
+    data_dtype,
+    conv_dtype,
+    groups=1,
+):
+    """Choose the best forward algorithm for the given input.
 
     Paramters
     ---------
@@ -329,23 +408,133 @@ def conv_find_algo(
     algo: int
         algo chosen by CUDNN
     """
-    dims = len(x_shape)
-    assert dims in (4, 5)
+    return _conv_find_algo(
+        "tvm.contrib.cudnn.conv.forward_find_algo",
+        tensor_format,
+        pad,
+        stride,
+        dilation,
+        x_shape,
+        w_shape,
+        y_shape,
+        data_dtype,
+        conv_dtype,
+        groups,
+    )
 
-    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
-        dims - 2, pad, stride, dilation, x_shape, w_shape
+
+def conv_backward_data_find_algo(
+    tensor_format,
+    pad,
+    stride,
+    dilation,
+    dy_shape,
+    w_shape,
+    dx_shape,
+    data_dtype,
+    conv_dtype,
+    groups=1,
+):
+    """Choose the best backward data algorithm for the given input.
+
+    Paramters
+    ---------
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+        2: CUDNN_TENSOR_NCHW_VECT_C
+    pad: int or list
+        padding
+    stride: int or list
+        stride
+    dilation: int or list
+        dilation
+    dy_shape: list
+        output gradient shape
+    w_shape: list
+        weight shape
+    dx_shape: list
+        dgrad shape
+    data_dtype: str
+        data type
+    conv_dtype: str
+        convolution type
+    groups: int
+        number of groups
+
+    Returns
+    -------
+    algo: int
+        algo chosen by CUDNN
+    """
+    return _conv_find_algo(
+        "tvm.contrib.cudnn.conv.backward_data_find_algo",
+        tensor_format,
+        pad,
+        stride,
+        dilation,
+        dy_shape,
+        w_shape,
+        dx_shape,
+        data_dtype,
+        conv_dtype,
+        groups,
     )
-    yshape = np.array(y_shape, dtype=np.int32)
-    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.find_algo")
-    return func(
+
+
+def conv_backward_filter_find_algo(
+    tensor_format,
+    pad,
+    stride,
+    dilation,
+    dy_shape,
+    x_shape,
+    dw_shape,
+    data_dtype,
+    conv_dtype,
+    groups=1,
+):
+    """Choose the best backward filter algorithm for the given input.
+
+    Paramters
+    ---------
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+        2: CUDNN_TENSOR_NCHW_VECT_C
+    pad: int or list
+        padding
+    stride: int or list
+        stride
+    dilation: int or list
+        dilation
+    dy_shape: list
+        output gradient shape
+    x_shape: list
+        weight shape
+    dw_shape: list
+        wgrad shape
+    data_dtype: str
+        data type
+    conv_dtype: str
+        convolution type
+    groups: int
+        number of groups
+
+    Returns
+    -------
+    algo: int
+        algo chosen by CUDNN
+    """
+    return _conv_find_algo(
+        "tvm.contrib.cudnn.conv.backward_filter_find_algo",
         tensor_format,
-        dims - 2,
-        _get_np_int32_array_handle(pad),
-        _get_np_int32_array_handle(stride),
-        _get_np_int32_array_handle(dilation),
-        _get_np_int32_array_handle(xshape),
-        _get_np_int32_array_handle(wshape),
-        _get_np_int32_array_handle(yshape),
+        pad,
+        stride,
+        dilation,
+        dy_shape,
+        x_shape,
+        dw_shape,
         data_dtype,
         conv_dtype,
         groups,
@@ -414,7 +603,7 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, 
tensor_format, algo, co
             if tensor_format == 1 and conv_dtype == "int32":
                 algo = 1
             else:
-                algo = conv_find_algo(
+                algo = conv_forward_find_algo(
                     tensor_format,
                     pad,
                     stride,
@@ -496,6 +685,189 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, 
tensor_format, algo, co
     )
 
 
+def conv_backward_data(
+    dy,
+    w,
+    pad,
+    stride,
+    dilation,
+    conv_mode,
+    tensor_format,
+    conv_dtype,
+    groups=1,
+    output_padding=(0, 0),
+):
+    """Create a CuDNN extern op that computes the gradient of 2D convolution 
with respect to data.
+
+    Parameters
+    ----------
+    dy: Tensor
+        output gradient
+    w: Tensor
+        convolution weight
+    pad: int or list
+        padding
+    stride: int or list
+        stride
+    dilation: int or list
+        dilation
+    conv_mode: int
+        0: CUDNN_CONVOLUTION
+        1: CUDNN_CROSS_CORRELATION
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+    conv_dtype: str
+        convolution type
+    groups: int
+        the number of groups
+
+    Returns
+    -------
+    dx: Tensor
+        dgrad tensor
+    """
+    dims = len(dy.shape)
+    assert dims == 4
+
+    conv_dtype = dy.dtype if conv_dtype is None else conv_dtype
+    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, 
stride, dilation)
+
+    assert isinstance(
+        dy.shape[0], tvm.tir.expr.IntImm
+    ), "Dynamic batch is not supported for cudnn conv2d backwad data yet."
+
+    dx_shape = conv_dgrad_shape(
+        tensor_format, pad, stride, dilation, dy.shape, w.shape, output_padding
+    )
+
+    algo = conv_backward_data_find_algo(
+        tensor_format,
+        pad,
+        stride,
+        dilation,
+        list(dy.shape),
+        list(w.shape),
+        dx_shape,
+        dy.dtype,
+        conv_dtype,
+        groups,
+    )
+
+    return te.extern(
+        dx_shape,
+        [dy, w],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d.backward_data",
+            conv_mode,
+            tensor_format,
+            algo,
+            pad[0],
+            pad[1],
+            stride[0],
+            stride[1],
+            dilation[0],
+            dilation[1],
+            ins[0],
+            ins[1],
+            outs[0],
+            conv_dtype,
+            groups,
+        ),
+        name="dx",
+    )
+
+
+def conv_backward_filter(
+    dy, x, kernel_size, pad, stride, dilation, conv_mode, tensor_format, 
conv_dtype, groups=1
+):
+    """Create a CuDNN extern op that computes the gradient of 2D convolution 
with respect to weight.
+
+    Parameters
+    ----------
+    dy: Tensor
+        output gradient
+    x: Tensor
+        input tensor
+    kernel_size: a pair of int
+        The spatial size of the corresponding forward convolution kernel
+    pad: int or list
+        padding
+    stride: int or list
+        stride
+    dilation: int or list
+        dilation
+    conv_mode: int
+        0: CUDNN_CONVOLUTION
+        1: CUDNN_CROSS_CORRELATION
+    tensor_format: int
+        0: CUDNN_TENSOR_NCHW
+        1: CUDNN_TENSOR_NHWC
+    conv_dtype: str
+        convolution type
+    groups: int
+        the number of groups
+
+    Returns
+    -------
+    dw: Tensor
+        wgrad tensor
+    """
+    dims = len(x.shape)
+    assert dims == 4
+
+    conv_dtype = x.dtype if conv_dtype is None else conv_dtype
+    pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, 
stride, dilation)
+    filter_h, filter_w = kernel_size
+
+    x_shape = list(x.shape)
+
+    assert isinstance(
+        x.shape[0], tvm.tir.expr.IntImm
+    ), "Dynamic batch is not supported for cudnn conv2d backwad filter yet."
+
+    if tensor_format == 0:
+        dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
+    else:
+        dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
+
+    algo = conv_backward_filter_find_algo(
+        tensor_format,
+        pad,
+        stride,
+        dilation,
+        list(dy.shape),
+        list(x.shape),
+        dw_shape,
+        x.dtype,
+        conv_dtype,
+        groups,
+    )
+
+    return te.extern(
+        dw_shape,
+        [dy, x],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d.backward_filter",
+            conv_mode,
+            tensor_format,
+            algo,
+            pad[0],
+            pad[1],
+            stride[0],
+            stride[1],
+            dilation[0],
+            dilation[1],
+            ins[0],
+            ins[1],
+            outs[0],
+            conv_dtype,
+            groups,
+        ),
+        name="dw",
+    )
+
+
 def softmax(x, axis=-1):
     """Compute softmax using CuDNN
 
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 2a941cc..1fa909e 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -1062,6 +1062,10 @@ reg.register_injective_schedule("nn.space_to_batch_nd")
 reg.register_injective_schedule("nn.batch_to_space_nd")
 
 
+reg.register_strategy("nn.conv2d_backward_weight", 
strategy.conv2d_backward_weight_strategy)
+reg.register_pattern("nn.conv2d_backward_weight", 
OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 @reg.register_legalize("nn.conv2d_backward_weight")
 def legalize_conv2d_backward_weight(attrs, inputs, types):
     """Legalize conv2d_backward_weight op.
diff --git a/python/tvm/relay/op/strategy/cuda.py 
b/python/tvm/relay/op/strategy/cuda.py
index 69579f6..af74514 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, 
out_type, target):
     return strategy
 
 
+@conv2d_backward_weight_strategy.register(["cuda"])
+def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d_backward_weight cuda strategy"""
+    strategy = _op.OpStrategy()
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            
wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="conv2d_backward_weight_strategy.cudnn",
+            plevel=15,
+        )
+    else:
+        raise RuntimeError(
+            "conv2d_backward_weight on cuda is currently only supported with 
cudnn. "
+            "Please run Legalize pass to decompose this op into supported ops."
+        )
+    return strategy
+
+
 @conv2d_transpose_strategy.register(["cuda", "gpu"])
 def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d_transpose cuda strategy"""
@@ -579,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, 
out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
         name="conv2d_transpose_nchw.cuda",
     )
+
+    if target.kind.name == "cuda" and "cudnn" in target.libs and 
attrs.kernel_layout == "IOHW":
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="conv2d_transpose.cudnn.cuda",
+            plevel=25,
+        )
+    # TODO(masahi): Support conv2d_transpose NHWC.
     return strategy
 
 
diff --git a/python/tvm/relay/op/strategy/generic.py 
b/python/tvm/relay/op/strategy/generic.py
index cc12fa1..abd3e28 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -1841,3 +1841,41 @@ def einsum_strategy(attrs, inputs, out_type, target):
         name="einsum.generic",
     )
     return strategy
+
+
+# conv2d_backward_weight
+def wrap_compute_conv2d_backward_weight(topi_compute):
+    """wrap conv2d_backward_weight topi compute"""
+
+    def _compute_conv2d_backward_weight(attrs, inputs, out_dtype):
+        kernel_size = get_const_tuple(attrs.kernel_size)
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        groups = attrs.groups
+        out_dtype = attrs.out_dtype
+        layout = attrs.data_layout
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        out = topi_compute(
+            inputs[0],
+            inputs[1],
+            kernel_size,
+            padding,
+            strides,
+            dilation,
+            groups,
+            layout,
+            out_dtype,
+        )
+        return [out]
+
+    return _compute_conv2d_backward_weight
+
+
+@override_native_generic_func("conv2d_backward_weight_strategy")
+def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
+    """wgrad generic strategy"""
+    raise RuntimeError(
+        "conv2d_backward_weight is currently only supported with cudnn. "
+        "Please run Legalize pass to decompose this op into supported ops."
+    )
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py
index bd8d7ec..15fcaaa 100644
--- a/python/tvm/topi/cuda/conv2d.py
+++ b/python/tvm/topi/cuda/conv2d.py
@@ -123,3 +123,22 @@ def conv2d_cudnn(
 def schedule_conv2d_cudnn(cfg, outs):
     """Create the schedule for conv2d_cudnn"""
     return generic.schedule_extern(outs)
+
+
+def conv2d_backward_weight_cudnn(
+    dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype
+):
+    """Compute conv2d wgrad using CuDNN library"""
+    assert layout in ["NCHW", "NHWC"]
+    return cudnn.conv_backward_filter(
+        dy,
+        x,
+        kernel_size,
+        padding,
+        stride,
+        dilation,
+        conv_mode=1,
+        tensor_format=0 if layout == "NCHW" else 1,
+        conv_dtype=output_dtype,
+        groups=groups,
+    )
diff --git a/python/tvm/topi/cuda/conv2d_transpose_nchw.py 
b/python/tvm/topi/cuda/conv2d_transpose_nchw.py
index 3b70417..36ce3a3 100644
--- a/python/tvm/topi/cuda/conv2d_transpose_nchw.py
+++ b/python/tvm/topi/cuda/conv2d_transpose_nchw.py
@@ -19,6 +19,7 @@
 
 import tvm
 from tvm import te
+from tvm.contrib import cudnn
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from .. import nn
@@ -286,3 +287,10 @@ def schedule_conv2d_transpose_nchw(cfg, outs):
     traverse_inline(s, outs[0].op, _callback)
 
     return s
+
+
+def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, 
output_padding=(0, 0)):
+    """Compute conv2d_tranpose using cudnn dgrad kernel"""
+    return cudnn.conv_backward_data(
+        x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, 
output_padding=output_padding
+    )
diff --git a/python/tvm/topi/nn/conv2d_transpose.py 
b/python/tvm/topi/nn/conv2d_transpose.py
index 2871699..c408095 100644
--- a/python/tvm/topi/nn/conv2d_transpose.py
+++ b/python/tvm/topi/nn/conv2d_transpose.py
@@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-
     data, kernel = inputs
     kernel_layout = attrs["kernel_layout"]
     if attrs["data_layout"] == "NHWC":
diff --git a/python/tvm/topi/testing/__init__.py 
b/python/tvm/topi/testing/__init__.py
index 75eabff..c3d222c 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -75,4 +75,4 @@ from .batch_to_space_nd import batch_to_space_nd_python
 from .nll_loss import nll_loss
 from .dense import dense
 from .searchsorted import searchsorted_ref
-from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
+from .conv2d_backcward_weight_python import conv2d_backward_weight_python
diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py 
b/python/tvm/topi/testing/conv2d_backcward_weight_python.py
index 587cd45..36a6b06 100644
--- a/python/tvm/topi/testing/conv2d_backcward_weight_python.py
+++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, 
kernel_size, stride, padding
 
     Returns
     -------
-    b_np : np.ndarray
+    dw_np : np.ndarray
         4-D with shape [num_filter, in_channel, filter_height, filter_width]
 
     """
@@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, 
kernel_size, stride, padding
                     dw[k, c, r, s] = acc
 
     return dw
+
+
+def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, 
layout="NCHW"):
+    """Gradient of the conv2d op with respect to weight, in NCHW or NHWC 
layout.
+
+    Parameters
+    ----------
+    dy_np : numpy.ndarray
+        4-D with shape [batch, in_channel, out_height, out_width] for NCHW 
layout
+
+    x_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout
+
+    kernel_size : tuple of two ints
+        Height and width of the weight
+
+    stride : tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : tuple of two ints
+        Spatial padding, or [pad_h, pad_w]
+
+    layout: string
+        Layout of dy_np and x_np
+
+    Returns
+    -------
+    dw_np : np.ndarray
+        Tensor of shape [num_filter, in_channel, filter_height, filter_width] 
for NCHW layout,
+        [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
+    """
+    if layout == "NCHW":
+        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, 
stride, padding)
+
+    dw_np_oihw = conv2d_backward_weight_nchw_python(
+        np.transpose(dy_np, [0, 3, 1, 2]),
+        np.transpose(x_np, [0, 3, 1, 2]),
+        kernel_size,
+        stride,
+        padding,
+    )
+    return np.transpose(dw_np_oihw, [0, 2, 3, 1])
diff --git a/python/tvm/topi/testing/conv2d_transpose_python.py 
b/python/tvm/topi/testing/conv2d_transpose_python.py
index a38d8bc..678b5fe 100644
--- a/python/tvm/topi/testing/conv2d_transpose_python.py
+++ b/python/tvm/topi/testing/conv2d_transpose_python.py
@@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, 
padding, output_padding):
             dilated_a_np.shape[2] + bpad_top + bpad_bottom,
             dilated_a_np.shape[3] + bpad_left + bpad_right,
         )
-    )
+    ).astype(a_np.dtype)
     padded_a_np[
         :,
         :,
@@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, 
padding, output_padding):
     # convolution stage
     out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
-    b_np = np.zeros((batch, out_c, out_h, out_w))
+    b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype)
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index f1d4eb3..30386bb 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -665,7 +665,6 @@ given the original input data and the output gradient.
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(2)
     .add_type_rel("Conv2DBackwardWeight", Conv2DBackwardWeightRel)
-    .set_attr<TNonComputational>("TNonComputational", true)
     .set_attr<FInferCorrectLayout>("FInferCorrectLayout", 
ConvInferCorrectLayout<Conv2DAttrs>);
 
 }  // namespace relay
diff --git a/src/runtime/contrib/cudnn/conv_backward.cc 
b/src/runtime/contrib/cudnn/conv_backward.cc
new file mode 100644
index 0000000..af190d7
--- /dev/null
+++ b/src/runtime/contrib/cudnn/conv_backward.cc
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file cuDNN kernel calls for backward algorithms.
+ */
+#include <tvm/runtime/data_type.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace contrib {
+
+using namespace runtime;
+
+void ConvolutionBackwardData(int mode, int format, int algo, int dims, int 
groups, const int pad[],
+                             const int stride[], const int dilation[], 
DLTensor* dy, DLTensor* w,
+                             DLTensor* dx, const std::string& conv_dtype) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Set Mode
+  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, 
dx->shape, w->shape,
+                     dy->shape, dy->dtype, conv_dtype);
+  // Set Device
+  entry_ptr->conv_entry.device = dy->device;
+  // Set Algo
+  entry_ptr->conv_entry.bwd_data_algo = 
static_cast<cudnnConvolutionBwdDataAlgo_t>(algo);
+
+  // Set workspace
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, 
entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+      entry_ptr->conv_entry.bwd_data_algo, &workspace_size));
+  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
+  CUDNN_CALL(cudnnConvolutionBackwardData(
+      entry_ptr->handle, 
CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.filter_desc, w->data, 
entry_ptr->conv_entry.output_desc, dy->data,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.bwd_data_algo,
+      entry_ptr->conv_entry.workspace, workspace_size,
+      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type), 
entry_ptr->conv_entry.input_desc,
+      dx->data));
+}
+
+void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
+                          const int dilation[], const int dy_dim[], const int w_dim[],
+                          const int dx_dim[], const std::string& data_dtype,
+                          const std::string& conv_dtype, TVMRetValue* ret) {
+  // Benchmarks the cuDNN backward-data algorithms for this problem, logs every candidate and
+  // stores the first (fastest-ranked) one in ret[0].
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Each shape array holds dims + 2 entries.
+  const int full_dims = dims + 2;
+  std::vector<int64_t> dy_dim_int64(dy_dim, dy_dim + full_dims);
+  std::vector<int64_t> w_dim_int64(w_dim, w_dim + full_dims);
+  std::vector<int64_t> dx_dim_int64(dx_dim, dx_dim + full_dims);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, dx_dim_int64.data(),
+                     w_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
+                     conv_dtype);
+
+  int returned_algo_count = 0;
+  cudnnConvolutionBwdDataAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT];
+  CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(
+      entry_ptr->handle, entry_ptr->conv_entry.filter_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
+      CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT, &returned_algo_count, perf_results));
+  // perf_results[0] is only initialized when cuDNN reported at least one candidate.
+  ICHECK_GT(returned_algo_count, 0) << "cuDNN returned no backward data algorithm";
+
+  // Indexed by the numeric value of cudnnConvolutionBwdDataAlgo_t.
+  const std::vector<std::string> bwd_data_algo_names{
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_0",  // non-deterministic
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_1",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD",
+      "CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED"};
+  // Bounds-checked lookup so an enum value unknown to this table cannot index past its end.
+  auto algo_name = [&bwd_data_algo_names](cudnnConvolutionBwdDataAlgo_t algo) -> std::string {
+    const size_t idx = static_cast<size_t>(algo);
+    return idx < bwd_data_algo_names.size() ? bwd_data_algo_names[idx]
+                                            : "UNKNOWN(" + std::to_string(idx) + ")";
+  };
+
+  auto best_algo = perf_results[0].algo;
+  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd data algorithms, choosing "
+            << algo_name(best_algo);
+  for (int i = 0; i < returned_algo_count; ++i) {
+    LOG(INFO) << "\t\t" << i << ") " << algo_name(perf_results[i].algo)
+              << " - time: " << perf_results[i].time << " ms"
+              << ", Memory: " << perf_results[i].memory;
+  }
+
+  ret[0] = best_algo;
+}
+
+void ConvolutionBackwardFilter(int mode, int format, int algo, int dims, int groups,
+                               const int pad[], const int stride[], const int dilation[],
+                               DLTensor* dy, DLTensor* x, DLTensor* dw,
+                               const std::string& conv_dtype) {
+  // Computes dw (gradient w.r.t. the filter) from the forward input x and the output
+  // gradient dy with the cuDNN backward-filter algorithm `algo`.
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  auto& conv = entry_ptr->conv_entry;
+
+  // Descriptor roles: x is the forward input, dw has the filter shape, dy the forward output.
+  conv.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, dw->shape,
+                     dy->shape, x->dtype, conv_dtype);
+  conv.device = x->device;
+  conv.bwd_filter_algo = static_cast<cudnnConvolutionBwdFilterAlgo_t>(algo);
+
+  // Ask cuDNN how much scratch memory the chosen algorithm needs and make it available.
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+      entry_ptr->handle, conv.input_desc, conv.output_desc, conv.conv_desc, conv.filter_desc,
+      conv.bwd_filter_algo, &workspace_size));
+  conv.UpdateWorkspace(workspace_size);
+
+  // cuDNN blends the result as alpha * result + beta * prior; GetConst<1>/GetConst<0> supply
+  // the scaling constants in the descriptor's data type.
+  const void* alpha = CuDNNDataType::GetConst<1>(conv.data_type);
+  const void* beta = CuDNNDataType::GetConst<0>(conv.data_type);
+  CUDNN_CALL(cudnnConvolutionBackwardFilter(entry_ptr->handle, alpha, conv.input_desc, x->data,
+                                            conv.output_desc, dy->data, conv.conv_desc,
+                                            conv.bwd_filter_algo, conv.workspace, workspace_size,
+                                            beta, conv.filter_desc, dw->data));
+}
+
+void BackwardFilterFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
+                            const int dilation[], const int dy_dim[], const int x_dim[],
+                            const int dw_dim[], const std::string& data_dtype,
+                            const std::string& conv_dtype, TVMRetValue* ret) {
+  // Benchmarks the cuDNN backward-filter algorithms for this problem, logs every candidate and
+  // stores the first (fastest-ranked) one in ret[0].
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Each shape array holds dims + 2 entries.
+  const int full_dims = dims + 2;
+  std::vector<int64_t> x_dim_int64(x_dim, x_dim + full_dims);
+  std::vector<int64_t> dy_dim_int64(dy_dim, dy_dim + full_dims);
+  std::vector<int64_t> dw_dim_int64(dw_dim, dw_dim + full_dims);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
+                     dw_dim_int64.data(), dy_dim_int64.data(), String2DLDataType(data_dtype),
+                     conv_dtype);
+
+  int returned_algo_count = 0;
+  cudnnConvolutionBwdFilterAlgoPerf_t perf_results[CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT];
+  CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(
+      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.filter_desc,
+      CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT, &returned_algo_count, perf_results));
+  // perf_results[0] is only initialized when cuDNN reported at least one candidate.
+  ICHECK_GT(returned_algo_count, 0) << "cuDNN returned no backward filter algorithm";
+
+  // Indexed by the numeric value of cudnnConvolutionBwdFilterAlgo_t. The cuDNN enum has
+  // WINOGRAD at value 4 between ALGO_3 and WINOGRAD_NONFUSED; omitting it shifts every later
+  // name by one and makes FFT_TILING (6) index past the end of the table.
+  const std::vector<std::string> bwd_filter_algo_names{
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0",  // non-deterministic
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1",
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT",
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3",
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD",
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED",
+      "CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING",
+  };
+  // Bounds-checked lookup so an enum value unknown to this table cannot index past its end.
+  auto algo_name = [&bwd_filter_algo_names](cudnnConvolutionBwdFilterAlgo_t algo) -> std::string {
+    const size_t idx = static_cast<size_t>(algo);
+    return idx < bwd_filter_algo_names.size() ? bwd_filter_algo_names[idx]
+                                              : "UNKNOWN(" + std::to_string(idx) + ")";
+  };
+
+  auto best_algo = perf_results[0].algo;
+  LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " bwd filter algorithms, choosing "
+            << algo_name(best_algo);
+  for (int i = 0; i < returned_algo_count; ++i) {
+    LOG(INFO) << "\t\t" << i << ") " << algo_name(perf_results[i].algo)
+              << " - time: " << perf_results[i].time << " ms"
+              << ", Memory: " << perf_results[i].memory;
+  }
+
+  ret[0] = best_algo;
+}
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      // Packed args: mode, format, algo, pad[2], stride[2], dilation[2], dy, w, dx,
+      // conv_dtype, groups.
+      const int mode = args[0];
+      const int format = args[1];
+      const int algo = args[2];
+      int pad_v[2];
+      int stride_v[2];
+      int dilation_v[2];
+      for (int axis = 0; axis < 2; ++axis) {
+        pad_v[axis] = args[3 + axis];
+        stride_v[axis] = args[5 + axis];
+        dilation_v[axis] = args[7 + axis];
+      }
+      DLTensor* dy = args[9];
+      DLTensor* w = args[10];
+      DLTensor* dx = args[11];
+      const std::string conv_dtype = args[12];
+      const int groups = args[13];
+
+      ConvolutionBackwardData(mode, format, algo, /*dims=*/2, groups, pad_v, stride_v, dilation_v,
+                              dy, w, dx, conv_dtype);
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      // Packed args: format, dims, opaque pointers to int arrays (pad, stride, dilation,
+      // dy shape, w shape, dx shape), data_dtype, conv_dtype, groups.
+      auto int_array = [&args](int index) {
+        void* raw = args[index];
+        return static_cast<int*>(raw);
+      };
+      const int format = args[0];
+      const int dims = args[1];
+      int* pad = int_array(2);
+      int* stride = int_array(3);
+      int* dilation = int_array(4);
+      int* dy_dim = int_array(5);
+      int* w_dim = int_array(6);
+      int* dx_dim = int_array(7);
+      const std::string data_dtype = args[8];
+      const std::string conv_dtype = args[9];
+      const int groups = args[10];
+
+      BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, w_dim, dx_dim,
+                           data_dtype, conv_dtype, ret);
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_filter")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      // Packed args: mode, format, algo, pad[2], stride[2], dilation[2], dy, x, dw,
+      // conv_dtype, groups.
+      const int mode = args[0];
+      const int format = args[1];
+      const int algo = args[2];
+      int pad_v[2];
+      int stride_v[2];
+      int dilation_v[2];
+      for (int axis = 0; axis < 2; ++axis) {
+        pad_v[axis] = args[3 + axis];
+        stride_v[axis] = args[5 + axis];
+        dilation_v[axis] = args[7 + axis];
+      }
+      DLTensor* dy = args[9];
+      DLTensor* x = args[10];
+      DLTensor* dw = args[11];
+      const std::string conv_dtype = args[12];
+      const int groups = args[13];
+
+      ConvolutionBackwardFilter(mode, format, algo, /*dims=*/2, groups, pad_v, stride_v,
+                                dilation_v, dy, x, dw, conv_dtype);
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_filter_find_algo")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      // Packed args: format, dims, opaque pointers to int arrays (pad, stride, dilation,
+      // dy shape, x shape, dw shape), data_dtype, conv_dtype, groups.
+      auto int_array = [&args](int index) {
+        void* raw = args[index];
+        return static_cast<int*>(raw);
+      };
+      const int format = args[0];
+      const int dims = args[1];
+      int* pad = int_array(2);
+      int* stride = int_array(3);
+      int* dilation = int_array(4);
+      int* dy_dim = int_array(5);
+      int* x_dim = int_array(6);
+      int* dw_dim = int_array(7);
+      const std::string data_dtype = args[8];
+      const std::string conv_dtype = args[9];
+      const int groups = args[10];
+
+      BackwardFilterFindAlgo(format, dims, groups, pad, stride, dilation, dy_dim, x_dim, dw_dim,
+                             data_dtype, conv_dtype, ret);
+    });
+
+}  // namespace contrib
+}  // namespace tvm
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc 
b/src/runtime/contrib/cudnn/conv_forward.cc
index b7476e5..f5e5ee8 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -18,7 +18,7 @@
  */
 
 /*!
- * \file Use external cudnn utils function
+ * \file cuDNN kernel calls for the forward algorithm.
  */
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/device_api.h>
@@ -147,7 +147,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
                          conv_dtype);
     });
 
-TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.forward_find_algo")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       int format = args[0];
       int dims = args[1];
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h 
b/src/runtime/contrib/cudnn/cudnn_utils.h
index 89de0e9..426ccfd 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -67,12 +67,14 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* 
strides) {
 struct ConvEntry {
   cudnnConvolutionDescriptor_t conv_desc;
   cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
-  cudnnFilterDescriptor_t filter_desc;
   cudnnDataType_t data_type;
   cudnnTensorFormat_t tensor_format;
   cudnnTensorDescriptor_t input_desc;
+  cudnnFilterDescriptor_t filter_desc;
   cudnnTensorDescriptor_t output_desc;
   cudnnConvolutionFwdAlgo_t fwd_algo;
+  cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
+  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo;
   // cudnnMathType_t math_type;
   Device device;
   runtime::DeviceAPI* cuda_api;
diff --git a/tests/python/contrib/test_cudnn.py 
b/tests/python/contrib/test_cudnn.py
index bc2cc80..0c39a1a 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -236,6 +236,144 @@ def test_softmax():
     verify_softmax_4d((1, 16, 256, 256), "float64", log_softmax=True)
 
 
+def verify_conv2d_backward_data(data_dtype, conv_dtype, tensor_format=0, 
tol=1e-5):
+    batch = 3
+    in_channel = 4
+    out_channel = 16
+    filter_h, filter_w = 3, 3
+    pad_h, pad_w = 1, 1
+    stride_h, stride_w = 1, 1
+    height, width = 32, 32
+
+    if tensor_format == 0:
+        xshape = [batch, in_channel, height, width]
+        wshape = [out_channel, in_channel, filter_h, filter_w]
+        oshape = xshape
+        oshape[1] = out_channel
+        ref_func = tvm.topi.testing.conv2d_transpose_nchw_python
+    else:
+        xshape = [batch, height, width, in_channel]
+        wshape = [out_channel, filter_h, filter_w, in_channel]
+        oshape = xshape
+        oshape[3] = out_channel
+        ref_func = lambda dy_np, w_np, strides, padding, out_pad: 
tvm.topi.testing.conv2d_transpose_nhwc_python(
+            dy_np, np.transpose(w_np, [1, 2, 3, 0]), "HWOI", strides, padding, 
out_pad
+        )
+
+    dy_np = np.random.uniform(-1, 1, oshape).astype(data_dtype)
+    w_np = np.random.uniform(-1, 1, wshape).astype(data_dtype)
+
+    if data_dtype == "float16":
+        dx_np = ref_func(
+            dy_np.astype("float32"),
+            w_np.astype("float32"),
+            (stride_h, stride_w),
+            (pad_h, pad_w),
+            (0, 0),
+        )
+        dx_np = dx_np.astype("float16")
+    else:
+        dx_np = ref_func(dy_np, w_np, (stride_h, stride_w), (pad_h, pad_w), 
(0, 0))
+
+    dy = te.placeholder(oshape, name="dy", dtype=data_dtype)
+    w = te.placeholder(wshape, name="dw", dtype=data_dtype)
+    dx = cudnn.conv_backward_data(
+        dy,
+        w,
+        [pad_h, pad_w],
+        [stride_h, stride_w],
+        [1, 1],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        conv_dtype=conv_dtype,
+        groups=1,
+    )
+
+    s = te.create_schedule(dx.op)
+
+    dev = tvm.cuda(0)
+    f = tvm.build(s, [dy, w, dx], "cuda --host=llvm", 
name="conv2d_backward_data")
+
+    dy = tvm.nd.array(dy_np, dev)
+    w = tvm.nd.array(w_np, dev)
+    dx = tvm.nd.array(dx_np, dev)
+
+    f(dy, w, dx)
+    tvm.testing.assert_allclose(dx.numpy(), dx_np, atol=tol, rtol=tol)
+
+
[email protected]_gpu
+@requires_cudnn
+def test_conv2d_backward_data():
+    verify_conv2d_backward_data("float32", "float32", tensor_format=0, 
tol=1e-5)
+    verify_conv2d_backward_data("float32", "float32", tensor_format=1, 
tol=1e-2)
+    # The scipy convolve function does not support fp16, so the reference will 
be computed with
+    # fp32. Use larger tolerance to be on the safe side (1e-2 also seems 
mostly ok).
+    verify_conv2d_backward_data("float16", "float16", tensor_format=1, 
tol=1e-1)
+
+
+def verify_conv2d_backward_filter(data_dtype, conv_dtype, tensor_format=0, 
tol=1e-5):
+    batch = 3
+    in_channel = 4
+    out_channel = 16
+    filter_h, filter_w = 3, 3
+    pad_h, pad_w = 1, 1
+    stride_h, stride_w = 1, 1
+    height, width = 32, 32
+
+    if tensor_format == 0:
+        x_shape = [batch, in_channel, height, width]
+        dy_shape = [batch, out_channel, height, width]
+    else:
+        x_shape = [batch, height, width, in_channel]
+        dy_shape = [batch, height, width, out_channel]
+
+    x_np = np.random.uniform(-1, 1, x_shape).astype(data_dtype)
+    dy_np = np.random.uniform(-1, 1, dy_shape).astype(data_dtype)
+
+    dw_np = tvm.topi.testing.conv2d_backward_weight_python(
+        dy_np,
+        x_np,
+        (filter_h, filter_w),
+        (stride_h, stride_w),
+        (pad_h, pad_w),
+        "NCHW" if tensor_format == 0 else "NHWC",
+    )
+
+    x = te.placeholder(x_shape, name="x", dtype=data_dtype)
+    dy = te.placeholder(dy_shape, name="dy", dtype=data_dtype)
+    dw = cudnn.conv_backward_filter(
+        dy,
+        x,
+        (filter_h, filter_w),
+        [pad_h, pad_w],
+        [stride_h, stride_w],
+        [1, 1],
+        conv_mode=1,
+        tensor_format=tensor_format,
+        conv_dtype=conv_dtype,
+    )
+
+    s = te.create_schedule(dw.op)
+
+    dev = tvm.cuda(0)
+    f = tvm.build(s, [dy, x, dw], "cuda --host=llvm", 
name="conv2d_backward_filter")
+
+    x = tvm.nd.array(x_np, dev)
+    dy = tvm.nd.array(dy_np, dev)
+    dw = tvm.nd.array(dw_np, dev)
+
+    f(dy, x, dw)
+    tvm.testing.assert_allclose(dw.numpy(), dw_np, atol=tol, rtol=tol)
+
+
[email protected]_gpu
+@requires_cudnn
+def test_conv2d_backward_filter():
+    verify_conv2d_backward_filter("float32", "float32", tensor_format=0, 
tol=1e-4)
+    verify_conv2d_backward_filter("float32", "float32", tensor_format=1, 
tol=1e-4)
+
+
 test_kwargs_default_2d = {
     "tensor_format": 0,
     "pad": [1, 1],
diff --git a/tests/python/relay/test_op_grad_level2.py 
b/tests/python/relay/test_op_grad_level2.py
index 1efdb26..a5fc630 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -233,27 +233,28 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, 
kernel_size, stride, paddin
     dtype = "float32"
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
-    dw = relay.nn.conv2d_backward_weight(
-        dy, x, strides=stride, padding=padding, kernel_size=kernel_size
+    dw_func = relay.Function(
+        [dy, x],
+        relay.nn.conv2d_backward_weight(
+            dy, x, strides=stride, padding=padding, kernel_size=kernel_size
+        ),
     )
-    dw_func = relay.Function([dy, x], dw)
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())
 
-    target = "llvm"
-    dev = tvm.device(target, 0)
-    dy_np = np.random.randn(*dy_shape).astype(dtype)
-    x_np = np.random.randn(*x_shape).astype(dtype)
+    for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda 
-libs=cudnn")]:
+        if "cudnn" in target and not tvm.contrib.cudnn.exists():
+            continue
 
-    dw_np = (
-        relay.create_executor(device=dev, target=target)
-        .evaluate(dw_func_legalized)(dy_np, x_np)
-        .numpy()
-    )
-    ref_dw_np = tvm.topi.testing.conv2d_backward_weight_nchw_python(
-        dy_np, x_np, kernel_size, stride, padding
-    )
+        dev = tvm.device(target, 0)
+        dy_np = np.random.randn(*dy_shape).astype(dtype)
+        x_np = np.random.randn(*x_shape).astype(dtype)
+
+        dw_np = relay.create_executor(device=dev, 
target=target).evaluate(dw)(dy_np, x_np).numpy()
+        ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
+            dy_np, x_np, kernel_size, stride, padding
+        )
 
-    np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
+        np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
 
 
 def test_conv2d_backward_weight():
diff --git a/tests/python/relay/test_op_level2.py 
b/tests/python/relay/test_op_level2.py
index db712be..6d428bf 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -24,7 +24,7 @@ import tvm
 import tvm.testing
 import tvm.topi.testing
 from tvm import autotvm, relay, te
-from tvm.contrib import utils
+from tvm.contrib import utils, cudnn
 from tvm.ir.module import IRModule
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
@@ -838,10 +838,10 @@ def test_conv2d_transpose_infer_type():
 @tvm.testing.uses_gpu
 def test_conv2d_transpose_nchw_run():
     k_layouts = {"OIHW": (10, 3, 3, 3), "IOHW": (3, 10, 3, 3)}
+    output_padding = (1, 1)
 
     for k_layout, kshape in k_layouts.items():
         dshape = (1, 3, 18, 18)
-        oshape = (1, 10, 36, 36)
         x = relay.var("x", shape=dshape)
         w = relay.var("w")
         y = relay.nn.conv2d_transpose(
@@ -851,7 +851,7 @@ def test_conv2d_transpose_nchw_run():
             kernel_size=(3, 3),
             strides=(2, 2),
             padding=(1, 1),
-            output_padding=(1, 1),
+            output_padding=output_padding,
             kernel_layout=k_layout,
             data_layout="NCHW",
         )
@@ -866,9 +866,16 @@ def test_conv2d_transpose_nchw_run():
         else:
             kernel_iohw = kernel
 
-        ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, 
kernel_iohw, 2, 1, (1, 1))
+        ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(
+            data, kernel_iohw, 2, 1, output_padding
+        )
 
-        for target, dev in tvm.testing.enabled_targets():
+        enabled_targets = tvm.testing.enabled_targets()
+
+        if cudnn.exists() and k_layout == "IOHW":
+            enabled_targets.append(("cuda -libs=cudnn", tvm.cuda(0)))
+
+        for target, dev in enabled_targets:
             op_res1 = relay.create_executor("graph", device=dev, 
target=target).evaluate(func)(
                 data, kernel
             )

Reply via email to