Qianshui-Jiang commented on code in PR #11571:
URL: https://github.com/apache/tvm/pull/11571#discussion_r891880282


##########
python/tvm/contrib/mkldnn.py:
##########
@@ -50,3 +51,107 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs):
         name="C",
         **kwargs,
     )
+
+
+def dnnl_conv2d(
+    src,
+    weights,
+    stride,
+    padding,
+    dilation,
+    groups,
+    channel_last=False,
+    out_dtype="float32",
+    **kwargs,
+):
+    """Convolution operator in NCHW or NHWC layout.
+
+    Parameters
+    ----------
+    src : tvm.te.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    weights : tvm.te.Tensor
+        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    groups: int
+        number of groups for grouped convolution
+
+    channel_last: bool
+        choose if input/output data format is in channel_last format (NHWC) or
+        in plain format (NCHW)
+
+    out_dtype: str
+        output datatype: now only support float32
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    if channel_last:
+        batch, in_height, in_width, _ = src.shape
+        kernel_h, kernel_w, _, num_filter = weights.shape
+    else:
+        batch, _, in_height, in_width = src.shape
+        num_filter, _, kernel_h, kernel_w = weights.shape
+
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
+    out_channel = num_filter
+    out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
+    out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
+
+    if channel_last:
+        out_shape = (batch, out_height, out_width, out_channel)
+    else:
+        out_shape = (batch, out_channel, out_height, out_width)
+
+    return te.extern(
+        out_shape,
+        [src, weights],
+        lambda ins, outs: tvm.tir.call_packed(
+            "tvm.contrib.mkldnn.conv2d",

Review Comment:
   Yep, that would be more concise, 💯 
   I'll try and commit later.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to