This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a571bfbbca [TOPI] Allow conv definition to have custom kernel layout (#11936)
a571bfbbca is described below

commit a571bfbbcab47be5b6573873de9acde353b99d14
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Jul 13 20:42:38 2022 -0700

    [TOPI] Allow conv definition to have custom kernel layout (#11936)
    
    * [TOPI] Allow conv definition to have custom kernel layout
    
    * add tests
    
    * fix
    
    * fix
---
 python/tvm/relay/op/strategy/arm_cpu.py           |  10 +-
 python/tvm/relay/op/strategy/cuda.py              |  11 +-
 python/tvm/relay/op/strategy/generic.py           |  15 ++-
 python/tvm/relay/op/strategy/hls.py               |   2 +-
 python/tvm/relay/op/strategy/intel_graphics.py    |   8 +-
 python/tvm/relay/op/strategy/rocm.py              |   2 +-
 python/tvm/relay/op/strategy/x86.py               |  10 +-
 python/tvm/topi/nn/conv1d.py                      |  35 +++++--
 python/tvm/topi/nn/conv2d.py                      | 117 ++++++++++++++--------
 python/tvm/topi/nn/conv3d.py                      |   3 +-
 tests/python/integration/test_winograd_nnpack.py  |   2 +-
 tests/python/topi/python/test_topi_conv2d_nhwc.py |  31 +++++-
 vta/python/vta/top/op.py                          |   2 +-
 13 files changed, 175 insertions(+), 73 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 4c5af610d7..7c48b09ff0 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -276,13 +276,15 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.arm_cpu.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.arm_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.arm_cpu",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -294,7 +296,9 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d_NCHWc adopted from x86"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 9c4a896d57..e3c74e15c2 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -316,10 +316,19 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ):
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, need_data_layout=True),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                 name="conv2d_NCHWc_int8.cuda",
             )
+        elif is_auto_scheduler_enabled():
+            strategy.add_implementation(
+                wrap_compute_conv2d(
+                    topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True
+                ),
+                naive_schedule,
+                name="conv2d.cuda",
+                plevel=15,
+            )
         elif target.kind.name == "cuda" and "cudnn" not in target.libs:
             # No TVM native kernel applicable
             raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 4ff7490b89..6074b0a69c 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -223,7 +223,9 @@ get_meta_schedule_original_shape = _ffi.get_global_func(
 # conv2d
 def wrap_compute_conv2d(
     topi_compute,
+    *,
     need_data_layout=False,
+    need_kernel_layout=False,
     need_out_layout=False,
     has_groups=False,
     need_auto_scheduler_layout=False,
@@ -236,6 +238,7 @@ def wrap_compute_conv2d(
         strides = get_const_tuple(attrs.strides)
         dilation = get_const_tuple(attrs.dilation)
         data_layout = attrs.get_str("data_layout")
+        kernel_layout = attrs.get_str("kernel_layout")
         out_layout = attrs.get_str("out_layout")
         out_dtype = attrs.out_dtype
         out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
@@ -244,6 +247,8 @@ def wrap_compute_conv2d(
             args.append(attrs.groups)
         if need_data_layout:
             args.append(data_layout)
+        if need_kernel_layout:
+            args.append(kernel_layout)
         if need_out_layout:
             args.append(out_layout)
         args.append(out_dtype)
@@ -340,13 +345,15 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.nn.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.generic",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.generic",
         )
@@ -360,7 +367,9 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.nn.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.generic",
     )
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 1eebbd36b8..4a682066ca 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -137,7 +137,7 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
     """conv2d_NCHWc hls strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
         wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.hls",
     )
diff --git a/python/tvm/relay/op/strategy/intel_graphics.py b/python/tvm/relay/op/strategy/intel_graphics.py
index a2de49c557..115a711144 100644
--- a/python/tvm/relay/op/strategy/intel_graphics.py
+++ b/python/tvm/relay/op/strategy/intel_graphics.py
@@ -44,7 +44,9 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
             # conv2d_NCHWc won't work without alter op layout pass
             # TODO(@Laurawly): fix this
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+                wrap_compute_conv2d(
+                    topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+                ),
                 wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
                 name="conv2d_NCHWc.intel_graphics",
                 plevel=5,
@@ -71,7 +73,9 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
     """conv2d_NCHWc intel_graphics strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.intel_graphics",
     )
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 6e91101826..89cac0db4a 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -44,7 +44,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
         and padding[1] == padding[3]
     ):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, need_data_layout=True),
             wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
             name="conv2d_nchw_miopen.rocm",
             plevel=50,
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index abbc9d9a4c..17474020ee 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -269,13 +269,15 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.x86.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.x86",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -287,7 +289,9 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d x86 strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
diff --git a/python/tvm/topi/nn/conv1d.py b/python/tvm/topi/nn/conv1d.py
index 0a1efa3565..ee388b4297 100644
--- a/python/tvm/topi/nn/conv1d.py
+++ b/python/tvm/topi/nn/conv1d.py
@@ -19,18 +19,27 @@
 from .conv2d import conv
 
 
-def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
+def conv1d(
+    data,
+    kernel,
+    strides=1,
+    padding="VALID",
+    dilation=1,
+    data_layout="NCW",
+    kernel_layout="",
+    out_dtype=None,
+):
     """1D convolution forward operator.
 
     Parameters
     ----------
     data : tvm.te.Tensor
-        3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
-        and [batch, in_width, in_channel] for layout == 'NWC'
+        3-D input shape [batch, in_channel, in_width] for data_layout == 'NCW'
+        and [batch, in_width, in_channel] for data_layout == 'NWC'
 
     kernel : tvm.te.Tensor
-        3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
-        and [filter_size, in_channel, num_filter] for layout == 'NWC'
+        3-D kernel with shape [num_filter, in_channel, filter_size] for kernel_layout == 'OIW'
+        and [filter_size, in_channel, num_filter] for kernel_layout == 'WIO'
 
     strides : int or tuple
         The spatial stride along width
@@ -41,23 +50,27 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", o
     dilation : int or tuple
         Dilation rate if convolution should be dilated.
 
-    layout : str
+    data_layout : str
         How input data is laid out, must be one of ['NCW', 'NWC']
 
+    kernel_layout : Optional[str]
+        The layout of the kernel. If unspecified, use default layout. "OIW" if data_layout == "NCW",
+        "WIO" if data_layout == "NWC".
+
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
-    return conv(data, kernel, strides, padding, dilation, 1, layout, out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)
 
 
 def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
     """1D convolution in NWC layout. See :py:func:`conv` for details on parameters"""
-    return conv(data, kernel, strides, padding, dilation, 1, "NWC", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, "NWC", "WIO", out_dtype=out_dtype)
 
 
 def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
     """1D convolution in NCW layout. See :py:func:`conv` for details on parameters"""
-    return conv(data, kernel, strides, padding, dilation, 1, "NCW", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, "NCW", "OIW", out_dtype=out_dtype)
 
 
 def group_conv1d_nwc(
@@ -89,7 +102,7 @@ def group_conv1d_nwc(
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
-    return conv(data, kernel, strides, padding, dilation, groups, "NWC", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, groups, "NWC", "WIO", out_dtype=out_dtype)
 
 
 def group_conv1d_ncw(
@@ -121,4 +134,4 @@ def group_conv1d_ncw(
     out_dtype : str
         The output data type. If None then output is same type as input.
     """
-    return conv(data, kernel, strides, padding, dilation, groups, "NCW", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, groups, "NCW", "OIW", out_dtype=out_dtype)
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index d23b8d857e..5070c84c7e 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -57,16 +57,18 @@ Workload = namedtuple(
 )
 
 
-def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=None):
+def conv2d(
+    input, filter, strides, padding, dilation, data_layout="NCHW", kernel_layout="", out_dtype=None
+):
     """Conv2D operator.
 
     Parameters
     ----------
     input : tvm.te.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width]
+        4-D with shape [batch, in_channel, in_height, in_width] in data_layout
 
     filter : tvm.te.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] in kernel_layout
 
     strides : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
@@ -79,9 +81,13 @@ def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=N
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
 
-    layout : str
+    data_layout : str
         layout of data
 
+    kernel_layout : Optional[str]
+        layout of kernel. If unspecified, use default layout inferred from data_layout. "OIHW" if
+        data_layout == "NCHW", "HWIO" if data_layout == "NHWC".
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -89,7 +95,7 @@ def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=N
     """
     # search platform specific declaration first
     # default declaration
-    return conv(input, filter, strides, padding, dilation, 1, layout, out_dtype)
+    return conv(input, filter, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)
 
 
 @tvm.target.generic_func
@@ -239,7 +245,7 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", out_dtype=out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", "OIHW", out_dtype=out_dtype)
 
 
 def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -269,7 +275,7 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
     output : tvm.te.Tensor
         4-D with shape [out_height, out_width, out_channel, batch]
     """
-    return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", out_dtype=out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", "HWIO", out_dtype=out_dtype)
 
 
 def conv2d_nhwc(
@@ -325,6 +331,7 @@ def conv2d_nhwc(
         dilation,
         1,
         "NHWC",
+        "HWIO",
         out_dtype,
         auto_scheduler_rewritten_layout,
         meta_schedule_original_shape,
@@ -708,7 +715,9 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
     Output : tvm.te.Tensor
         4-D with shape [batch, out_channel, out_height, out_width]
     """
-    return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", out_dtype=out_dtype)
+    return conv(
+        Input, Filter, stride, padding, dilation, groups, "NCHW", "OIHW", out_dtype=out_dtype
+    )
 
 
 def conv(
@@ -718,7 +727,8 @@ def conv(
     padding: Union[int, Sequence[int]],
     dilation: Union[int, Sequence[int]],
     groups: int,
-    order: str,
+    data_layout: str,
+    kernel_layout: str = "",
     out_dtype: Union[str, None] = None,
     auto_scheduler_rewritten_layout: Optional[str] = None,
     meta_schedule_original_shape=None,
@@ -731,11 +741,11 @@ def conv(
     Parameters
     ----------
     inp : tvm.te.Tensor
-        N-D with shape [batch, in_channel, in_height, in_width, ...] ordered by `order`
+        N-D with shape [batch, in_channel, in_height, in_width, ...] in `data_layout`
 
     filt : tvm.te.Tensor
-        N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...]
-        for NCHW or [filter_height, filter_width, ..., in_channel // groups, num_filter] for NHWC
+        N-D with shape [num_filter, in_channel // groups, filter_height, filter_width, ...] in
+        `kernel_layout`
 
     stride : int or a list/tuple of dim ints
         (where dim=2 for NCHW, dim=1 for NCH, etc.)
@@ -753,10 +763,16 @@ def conv(
     groups : int
         number of groups
 
-    order : str
-        Ordering of dimensions. N indicates batch dimension, C indicates
+    data_layout : str
+        Layout of the input. N indicates batch dimension, C indicates
         channels, any other character indicates HW (or H or HWD for 1D and 3D).
 
+    kernel_layout : Optional[str]
+        Layout of the filter. I indicates input channels, O indicates output channels,
+        any other character indicates HW dimension of the filter (or H or HWD for 1D and 3D).
+        If kernel_layout is empty, use data_layout to infer the default kernel_layout. Default
+        kernel_layout is OIHW for NCHW data layout, HWIO for NHWC data layout.
+
     out_dtype : str
         Elements are converted to this type before elementwise multiplication
         and summation.
@@ -775,7 +791,7 @@ def conv(
     Returns
     -------
     Output : tvm.te.Tensor
-        N-D with shape [batch, out_channel, out_height, out_width, ...] ordered by `order`.
+        N-D with shape [batch, out_channel, out_height, out_width, ...] in `data_layout`
     """
     dim = len(inp.shape) - 2
     if out_dtype is None:
@@ -792,30 +808,41 @@ def conv(
     else:
         dilations = list(dilation)
 
-    # transform from order to NCHW
-    permutation_to = [order.find("N"), order.find("C")] + [
-        x.span()[0] for x in re.finditer("[^NC]", order)
+    # transform from data_layout to NCHW
+    data_permutation_to = [data_layout.find("N"), data_layout.find("C")] + [
+        x.span()[0] for x in re.finditer("[^NC]", data_layout)
     ]
-    # transform from NCHW to order
-    permutation_from = np.argsort(permutation_to)
-    # transform from CHW to order
-    permutation_from_reductions = permutation_from[1:].copy()
-    permutation_from_reductions[permutation_from_reductions > 
permutation_from[0]] -= 1
-
-    # kernel permutation, if C appears before HW then num_filter is first, otherwise it is last
-    # tkonolige: I don't really understand kernel ordering for NHWC, it seems
-    # like num_filters should match the N dimension
-    if order.find("C") < re.search("[^NC]", order).span()[0]:
-        permutation_to_kernel = [0, 1] + list(range(2, dim + 2))
+    # transform from NCHW to data_layout
+    data_permutation_from = np.argsort(data_permutation_to)
+    # transform from CHW to data_layout
+    data_permutation_from_reductions = data_permutation_from[1:].copy()
+    data_permutation_from_reductions[
+        data_permutation_from_reductions > data_permutation_from[0]
+    ] -= 1
+
+    if kernel_layout == "":
+        # kernel permutation, if C appears before HW then num_filter is first, otherwise it is last
+        # tkonolige: I don't really understand kernel ordering for NHWC, it seems
+        # like num_filters should match the N dimension
+        if data_layout.find("C") < re.search("[^NC]", data_layout).span()[0]:
+            kernel_permutation_to = [0, 1] + list(range(2, dim + 2))
+        else:
+            kernel_permutation_to = [dim + 1, dim] + list(range(dim))
     else:
-        permutation_to_kernel = [dim + 1, dim] + list(range(dim))
-    permutation_from_kernel = np.argsort(permutation_to_kernel)
+        # transform from kernel_layout to OIHW
+        kernel_permutation_to = [kernel_layout.find("O"), kernel_layout.find("I")] + [
+            x.span()[0] for x in re.finditer("[^OI]", kernel_layout)
+        ]
+    # transform from OIHW to kernel_layout
+    kernel_permutation_from = np.argsort(kernel_permutation_to)
 
     if meta_schedule_original_shape:
         auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape)
-    batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist()
+    batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[
+        data_permutation_to
+    ].tolist()
     num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
-        permutation_to_kernel
+        kernel_permutation_to
     ].tolist()
 
     # Autoscheduler may have messed with the input layout, so we extract the
@@ -841,14 +868,14 @@ def conv(
         )
     ]
     # compute graph
-    pad_before = list(np.array([0, 0] + pad_begin)[permutation_from])
-    pad_after = list(np.array([0, 0] + pad_end)[permutation_from])
+    pad_before = list(np.array([0, 0] + pad_begin)[data_permutation_from])
+    pad_after = list(np.array([0, 0] + pad_end)[data_permutation_from])
     temp = pad(inp, pad_before, pad_after, name="pad_temp")
     rc = te.reduce_axis((0, in_channel // groups), name="rc")
     rs = [te.reduce_axis((0, k), name=f"r{i}") for i, k in zip(["y", "x", "z"], kernel_dimensions)]
 
     def compute(*args):
-        nn, ff, *dim_indices = list(np.array(args)[permutation_to])
+        nn, ff, *dim_indices = list(np.array(args)[data_permutation_to])
 
         if groups == 1:
             simplified_channel_index = rc
@@ -864,25 +891,25 @@ def conv(
                             di * stride + r * dil
                             for di, stride, r, dil in zip(dim_indices, strides, rs, dilations)
                         ]
-                    )[permutation_from]
+                    )[data_permutation_from]
                 )
             ).astype(out_dtype)
-            * filt.__getitem__(tuple(np.array([ff, rc] + rs)[permutation_from_kernel])).astype(
+            * filt.__getitem__(tuple(np.array([ff, rc] + rs)[kernel_permutation_from])).astype(
                 out_dtype
             ),
             # Schedules depend on reduction axes being in the same order as the
             # layout, so we reorder here.
-            axis=np.array([rc, *rs])[permutation_from_reductions].tolist(),
+            axis=np.array([rc, *rs])[data_permutation_from_reductions].tolist(),
         )
 
     out = te.compute(
-        list(np.array([batch, out_channel] + out_dimensions)[permutation_from]),
+        list(np.array([batch, out_channel] + out_dimensions)[data_permutation_from]),
         compute,
         # tag is expected to be lowercase
-        tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
-        name=f"{'group_' if groups > 1 else ''}conv{dim}d_{order.lower()}",
+        tag=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}",
+        name=f"{'group_' if groups > 1 else ''}conv{dim}d_{data_layout.lower()}",
         attrs={"layout_free_placeholders": [filt]} if auto_scheduler_should_rewrite_layout else {},
-        varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[permutation_from]),
+        varargs_names=list(np.array(["nn", "ff", "yy", "xx", "zz"])[data_permutation_from]),
     )
     # if we used autoscheduler's changed layout we need to rewrite the ordering
     # of the output dimensions
@@ -924,7 +951,9 @@ def group_conv2d_nhwc(Input, Filter, stride, padding, dilation, groups, out_dtyp
     Output : tvm.te.Tensor
         4-D with shape [batch, out_height, out_width, out_channel]
     """
-    return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", out_dtype=out_dtype)
+    return conv(
+        Input, Filter, stride, padding, dilation, groups, "NHWC", "HWIO", out_dtype=out_dtype
+    )
 
 
 def unpack_NCHWc_to_nchw(packed_out, out_dtype):
diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py
index 591c643a95..1897484dc8 100644
--- a/python/tvm/topi/nn/conv3d.py
+++ b/python/tvm/topi/nn/conv3d.py
@@ -53,7 +53,7 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=Non
     Output : tvm.te.Tensor
         5-D with shape [batch, out_channel, out_depth, out_height, out_width]
     """
-    return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", "OIDHW", out_dtype)
 
 
 def conv3d_ndhwc(
@@ -111,6 +111,7 @@ def conv3d_ndhwc(
         dilation,
         groups,
         "NDHWC",
+        "DHWIO",
         out_dtype,
         auto_scheduler_rewritten_layout,
         meta_schedule_origin_shape,
diff --git a/tests/python/integration/test_winograd_nnpack.py b/tests/python/integration/test_winograd_nnpack.py
index b088b350c9..9d9f4e10e6 100644
--- a/tests/python/integration/test_winograd_nnpack.py
+++ b/tests/python/integration/test_winograd_nnpack.py
@@ -86,7 +86,7 @@ def verify_conv2d_nchw(
                 stride,
                 padding,
                 dilation,
-                layout="NCHW",
+                data_layout="NCHW",
                 out_dtype=dtype,
             )
             if add_bias:
diff --git a/tests/python/topi/python/test_topi_conv2d_nhwc.py b/tests/python/topi/python/test_topi_conv2d_nhwc.py
index 362de3a769..e60cf12aa8 100644
--- a/tests/python/topi/python/test_topi_conv2d_nhwc.py
+++ b/tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -77,7 +77,7 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
     return a_np, w_np, b_np
 
 
-def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation):
+def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
     a_np, w_np, b_np = ref_data
 
     A = te.placeholder(a_np.shape, name="A", dtype=dtype)
@@ -95,5 +95,34 @@ def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation):
     tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
 
 
+def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
+    # only test on CPU target because topi doesn't have schedules for this layout
+    target = "llvm"
+    dev = tvm.device(target, 0)
+    a_np, w_np_hwio, b_np = ref_data
+    w_np_ohwi = w_np_hwio.transpose(3, 0, 1, 2)  # HWIO -> OHWI
+
+    A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+    W = te.placeholder(w_np_ohwi.shape, name="W", dtype=dtype)
+
+    B = topi.nn.conv2d(
+        A,
+        W,
+        stride,
+        padding,
+        dilation,
+        data_layout="NHWC",
+        kernel_layout="OHWI",
+        out_dtype="float32",
+    )
+    s = tvm.te.create_schedule(B.op)
+    a = tvm.nd.array(a_np, dev)
+    w = tvm.nd.array(w_np_ohwi, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    func = tvm.build(s, [A, W, B], target)
+    func(a, w, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py
index 6b06d88096..4fa5b6ff84 100644
--- a/vta/python/vta/top/op.py
+++ b/vta/python/vta/top/op.py
@@ -214,7 +214,7 @@ def conv2d_strategy_vta(attrs, inputs, out_type, target):
             assert kernel.dtype == "int8"
 
             strategy.add_implementation(
-                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_compute_conv2d(conv2d_packed, need_data_layout=True),
                 _strategy.wrap_topi_schedule(schedule_conv2d_packed),
                 name="conv2d_packed.vta",
             )

Reply via email to