This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1cf0c0a5bf [CUDNN] Add partitioning support for fused conv2d+bias+act (#10997)
1cf0c0a5bf is described below

commit 1cf0c0a5bfa6b0c61ce253142b66f6235d694e07
Author: Matthew Barrett <[email protected]>
AuthorDate: Thu Apr 14 09:50:17 2022 +0100

    [CUDNN] Add partitioning support for fused conv2d+bias+act (#10997)
    
    cuDNN has kernel support for the pattern conv2d+bias+act,
    although as of v8 only relu is supported as the activation.
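
    A minimal usage sketch (illustrative only; the shapes, dtype, and variable
    names below are assumptions, not taken from this commit) of driving the new
    partitioning entry point from Python:

        import tvm
        from tvm import relay
        from tvm.relay.op.contrib.cudnn import partition_for_cudnn

        # Build a conv2d -> bias_add -> relu chain that matches the new
        # "cudnn.conv2d_bias_act" pattern.
        data = relay.var("data", relay.TensorType((1, 8, 16, 16), "float32"))
        weight = relay.var("weight", relay.TensorType((16, 8, 3, 3), "float32"))
        bias = relay.var("bias", relay.TensorType((16,), "float32"))
        out = relay.nn.conv2d(data, weight, channels=16, kernel_size=(3, 3), padding=(1, 1))
        out = relay.nn.bias_add(out, bias)
        out = relay.nn.relu(out)

        mod = tvm.IRModule.from_expr(out)
        # The whole chain should be grouped into a single composite function
        # tagged "cudnn.conv2d_bias_act" and offloaded to the cuDNN BYOC target.
        partitioned = partition_for_cudnn(mod)
        print(partitioned)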
---
 python/tvm/relay/op/contrib/cudnn.py      | 79 +++++++++++++++++++++++++++----
 src/runtime/contrib/cudnn/conv_forward.cc | 62 ++++++++++++++++++++++++
 src/runtime/contrib/cudnn/cudnn_utils.cc  |  4 ++
 src/runtime/contrib/cudnn/cudnn_utils.h   |  2 +
 tests/python/contrib/test_cudnn.py        | 51 ++++++++++++++++++--
 5 files changed, 186 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cudnn.py b/python/tvm/relay/op/contrib/cudnn.py
index 9714a0b87d..e3c256f7e3 100644
--- a/python/tvm/relay/op/contrib/cudnn.py
+++ b/python/tvm/relay/op/contrib/cudnn.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=unused-argument
 """cuDNN Relay integration."""
-from typing import Callable, List, Tuple, Dict, Optional
+from typing import Callable, List, Tuple
 
 import tvm
 import tvm.ir
@@ -24,7 +24,6 @@ from tvm import relay
 from tvm import te
 from tvm.relay import transform
 from tvm.contrib import cudnn
-from tvm.relay.build_module import bind_params_by_name
 
 from ...dataflow_pattern import is_op, wildcard
 from .te_target import lower_composite, relay_to_runtime
@@ -34,25 +33,19 @@ from .register import register_pattern_table
 tvm._ffi.register_func("relay.ext.cudnn", relay_to_runtime(tvm.target.cuda()))
 
 
-def partition_for_cudnn(
-    mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
-) -> tvm.IRModule:
+def partition_for_cudnn(mod: tvm.IRModule) -> tvm.IRModule:
     """Partition the graph to offload for cuDNN.
 
     Parameters
     ----------
     mod : tvm.IRModule
         The module to partition.
-    params : Optional[Dict[str, tvm.runtime.NDArray]]
-        Constant input parameters.
 
     Returns
     -------
     tvm.IRModule
         The partitioned module.
     """
-    if params:
-        mod["main"] = bind_params_by_name(mod["main"], params)
 
     seq = tvm.transform.Sequential(
         [
@@ -82,6 +75,12 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
         """Create pattern for conv2d."""
         return is_op("nn.conv2d")(wildcard(), wildcard())
 
+    def conv2d_bias_act_pattern() -> relay.Pattern:
+        """Create pattern for fused conv2d+bias+activation."""
+        conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
+        bias = is_op("nn.bias_add")(conv2d, wildcard())
+        return bias.optional(is_op("nn.relu"))
+
     def check_softmax(matched: relay.Call) -> bool:
         """Check if softmax is supported by cuDNN."""
         if matched.args[0].checked_type.dtype not in ["float64", "float32", "float16"]:
@@ -115,9 +114,13 @@ def pattern_table() -> List[Tuple[str, relay.Pattern, Callable[[relay.Call], boo
 
         return True
 
+    def check_conv2d_bias_act(matched: relay.Call) -> bool:
+        return True
+
     return [
         ("cudnn.softmax", softmax_pattern(), check_softmax),
         ("cudnn.log_softmax", log_softmax_pattern(), check_log_softmax),
+        ("cudnn.conv2d_bias_act", conv2d_bias_act_pattern(), 
check_conv2d_bias_act),
         ("cudnn.conv2d", conv2d_pattern(), check_conv2d),
     ]
 
@@ -134,6 +137,64 @@ def _lower_log_softmax(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     return cudnn.log_softmax(inputs[0], axis=op.attrs["axis"])
 
 
+@lower_composite("cudnn.conv2d_bias_act")
+def _lower_conv2d_bias_act(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
+    """Lower a fused conv2d+bias+activation using cuDNN."""
+    conv_dtype = op.checked_type.dtype
+    if op.op.name == "nn.relu":
+        activation_mode = 1  # Relu
+        conv2d = op.args[0].args[0]
+    else:
+        activation_mode = 5  # Identity
+        conv2d = op.args[0]
+
+    conv_mode = 1
+    tensor_format = 0
+    algo = 1
+    pad = conv2d.attrs["padding"]
+    strides = conv2d.attrs["strides"]
+    dilation = conv2d.attrs["dilation"]
+    groups = conv2d.attrs["groups"]
+
+    oshape = cudnn.conv_output_shape(
+        tensor_format,
+        pad,
+        strides,
+        dilation,
+        inputs[0].shape,
+        inputs[1].shape,
+        inputs[0].dtype,
+        conv_dtype,
+        groups,
+    )
+
+    return te.extern(
+        oshape,
+        inputs,
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.cudnn.conv2d+bias+act.forward",
+            conv_mode,
+            tensor_format,
+            algo,
+            pad[0],
+            pad[1],
+            strides[0],
+            strides[1],
+            dilation[0],
+            dilation[1],
+            activation_mode,
+            0,
+            ins[0],
+            ins[1],
+            ins[2],
+            outs[0],
+            conv_dtype,
+            groups,
+        ),
+        name="y",
+    )
+
+
 @lower_composite("cudnn.conv2d")
 def _lower_conv2d(op: relay.Call, inputs: List[te.Tensor]) -> te.Tensor:
     """Lower a conv2d using cuDNN."""
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index f5e5ee889c..626d356da4 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -60,6 +60,44 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
       entry_ptr->conv_entry.output_desc, y->data));
 }
 
+void ConvolutionBiasActivationForward(int mode, int format, int algo, int dims, int groups, int act,
+                                      double coef, const int pad[], const int stride[],
+                                      const int dilation[], DLTensor* x, DLTensor* w, DLTensor* y,
+                                      DLTensor* bias, const std::string& conv_dtype) {
+  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+  // Set Mode
+  entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+  CUDNN_CALL(cudnnSetActivationDescriptor(entry_ptr->conv_entry.activation_desc,
+                                          static_cast<cudnnActivationMode_t>(act),
+                                          cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN, coef));
+  CUDNN_CALL(cudnnSetTensor4dDescriptor(
+      entry_ptr->conv_entry.bias_desc, entry_ptr->conv_entry.tensor_format,
+      CuDNNDataType::DLTypeToCuDNNType(bias->dtype), 1, static_cast<int>(w->shape[0]), 1, 1));
+
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
+                     y->shape, x->dtype, conv_dtype);
+  // Set Device
+  entry_ptr->conv_entry.device = x->device;
+  // Set Algo
+  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+
+  // Set workspace
+  size_t workspace_size = 0;
+  CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
+      entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+      entry_ptr->conv_entry.fwd_algo, &workspace_size));
+  entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
+  CUDNN_CALL(cudnnConvolutionBiasActivationForward(
+      entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
+      entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+      entry_ptr->conv_entry.workspace, workspace_size,
+      CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+      entry_ptr->conv_entry.output_desc, y->data, entry_ptr->conv_entry.bias_desc, bias->data,
+      entry_ptr->conv_entry.activation_desc, entry_ptr->conv_entry.output_desc, y->data));
+}
+
 void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
               const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
               const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
@@ -126,6 +164,30 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.forward")
                          conv_dtype);
     });
 
+TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d+bias+act.forward")
+    .set_body([](TVMArgs args, TVMRetValue* ret) {
+      int mode = args[0];
+      int format = args[1];
+      int algo = args[2];
+      int pad_v[2], stride_v[2], dilation_v[2];
+      for (int i = 0; i < 2; i++) {
+        pad_v[i] = args[3 + i];
+        stride_v[i] = args[5 + i];
+        dilation_v[i] = args[7 + i];
+      }
+      int act = args[9];
+      double coef = args[10];
+      DLTensor* x = args[11];
+      DLTensor* w = args[12];
+      DLTensor* bias = args[13];
+      DLTensor* y = args[14];
+      std::string conv_dtype = args[15];
+      int groups = args[16];
+
+      ConvolutionBiasActivationForward(mode, format, algo, 2, groups, act, coef, pad_v, stride_v,
+                                       dilation_v, x, w, y, bias, conv_dtype);
+    });
+
 TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       int mode = args[0];
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index e39c47339c..68d5902c06 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -140,6 +140,8 @@ ConvEntry::ConvEntry() {
   CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&input_desc));
   CUDNN_CALL(cudnnCreateTensorDescriptor(&output_desc));
+  CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc));
+  CUDNN_CALL(cudnnCreateActivationDescriptor(&activation_desc));
 }
 
 ConvEntry::~ConvEntry() {
@@ -147,6 +149,8 @@ ConvEntry::~ConvEntry() {
   CUDNN_CALL(cudnnDestroyConvolutionDescriptor(conv_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(input_desc));
   CUDNN_CALL(cudnnDestroyTensorDescriptor(output_desc));
+  CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc));
+  CUDNN_CALL(cudnnDestroyActivationDescriptor(activation_desc));
   CleanWorkspace();
 }
 
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 426ccfdf37..871fb35dd4 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -71,6 +71,8 @@ struct ConvEntry {
   cudnnTensorFormat_t tensor_format;
   cudnnTensorDescriptor_t input_desc;
   cudnnFilterDescriptor_t filter_desc;
+  cudnnTensorDescriptor_t bias_desc;
+  cudnnActivationDescriptor_t activation_desc;
   cudnnTensorDescriptor_t output_desc;
   cudnnConvolutionFwdAlgo_t fwd_algo;
   cudnnConvolutionBwdDataAlgo_t bwd_data_algo;
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 8ca3df343d..cdbe424710 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -461,10 +461,12 @@ def _verify_cudnn_relay(expr):
     for param in func.params:
         shape = [int(x) for x in param.checked_type.shape]
         input_data.append(
-            (param.name_hint, np.random.uniform(0, 32, size=shape).astype(param.checked_type.dtype))
+            (
+                param.name_hint,
+                np.random.uniform(-32, 32, size=shape).astype(param.checked_type.dtype),
+            )
         )
 
-    # Test against CPU reference
     cuda_config = (tvm.target.cuda(), tvm.cuda(), cudnn_mod)
     cpu_config = (tvm.target.Target("llvm"), tvm.cpu(), mod)
     outputs = []
@@ -484,7 +486,8 @@ def _verify_cudnn_relay(expr):
     tvm.testing.assert_allclose(
         outputs[0],
         outputs[1],
-        rtol=1e-2,
+        rtol=1e-3,
+        atol=30,
     )
 
 
@@ -577,5 +580,47 @@ def test_relay_cudnn_conv2d(n, h, w, ci, co, kh, kw, strides, dilation, padding,
     _verify_cudnn_relay(conv2d)
 
 
+@tvm.testing.requires_cuda
+@pytest.mark.parametrize(
+    "n,h,w,ci,co,groups",
+    [
+        (1, 16, 20, 8, 16, 1),
+        (10, 17, 19, 16, 8, 4),
+    ],
+)
+@pytest.mark.parametrize(
+    "kh,kw,padding,strides,dilation,dtype",
+    [
+        (1, 1, (3, 1, 3, 1), (1, 1), (1, 1), "float32"),
+        (3, 3, (1, 2), (2, 1), (2, 2), "float16"),
+        (7, 2, (0, 0), (3, 3), (1, 2), "float64"),
+    ],
+)
+@pytest.mark.parametrize("activation", [True, False])
+def test_relay_cudnn_conv2d_bias_act(
+    n, h, w, ci, co, kh, kw, strides, dilation, padding, groups, dtype, activation
+):
+    data = tvm.relay.var("data", tvm.relay.TensorType((n, ci, h, w), dtype))
+    weight = tvm.relay.var("weight", tvm.relay.TensorType((co, ci // groups, kh, kw), dtype))
+    bias = relay.var("bias", relay.TensorType((co,), dtype))
+    conv2d = relay.op.nn.conv2d(
+        data,
+        weight,
+        groups=groups,
+        channels=co,
+        kernel_size=(kh, kw),
+        strides=strides,
+        dilation=dilation,
+        padding=padding,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+    )
+    out = relay.op.nn.bias_add(conv2d, bias)
+    if activation:
+        out = relay.op.nn.relu(out)
+
+    _verify_cudnn_relay(out)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
