This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6159b8e  [Topi][Op][PyTorch][Vitis] Fix inconsistent kernel layout conventions for conv2d_transpose (#9336)
6159b8e is described below

commit 6159b8e925bb77324ecd488e161a6c368b9d4ce7
Author: AndrewZhaoLuo <[email protected]>
AuthorDate: Thu Nov 11 16:45:54 2021 -0800

    [Topi][Op][PyTorch][Vitis] Fix inconsistent kernel layout conventions for conv2d_transpose (#9336)
    
    * fix a lot of initial tests
    
    * make pytorch tests pass
    
    * lint
    
    * add test
    
    * fix bug with layout transform
    
    * change layouts for conv2d_transpose too
    
    * fix vitis tests
    
    * fix qnn conv2d transpose tests
    
    * fix fake quantization pass
    
    * add todo
    
    * lint
    
    * undo just formatting changes
    
    * remove formatting only change
    
    * remove f2qi for later pr
    
    * more frontend tests fixes
    
    * fix a lot of initial tests
    
    * make pytorch tests pass
    
    * lint
    
    * add test
    
    * fix bug with layout transform
    
    * change layouts for conv2d_transpose too
    
    * fix vitis tests
    
    * fix qnn conv2d transpose tests
    
    * fix fake quantization pass
    
    * add todo
    
    * lint
    
    * undo just formatting changes
    
    * remove formatting only change
    
    * remove f2qi for later pr
    
    * more frontend tests fixes
    
    * jostle
    
    * fix keras
    
    * fix another frontend test
    
    * fix things
    
    * jostle ci
---
 python/tvm/relay/frontend/caffe.py                 |  7 +-
 python/tvm/relay/frontend/keras.py                 |  7 +-
 python/tvm/relay/frontend/mxnet.py                 | 46 ++++++++-----
 python/tvm/relay/frontend/pytorch.py               | 11 +--
 python/tvm/relay/frontend/qnn_torch.py             |  8 +--
 python/tvm/relay/frontend/tensorflow_ops.py        |  5 +-
 python/tvm/relay/frontend/tflite.py                | 39 ++++++-----
 python/tvm/relay/op/nn/nn.py                       |  2 +-
 python/tvm/relay/qnn/op/layout_conversions.py      |  2 +-
 python/tvm/relay/qnn/op/qnn.py                     |  8 ++-
 python/tvm/relay/testing/dcgan.py                  |  3 +-
 python/tvm/topi/nn/conv2d_transpose.py             | 80 ++++++++++++++--------
 src/relay/op/nn/convolution.h                      |  6 +-
 .../contrib/test_vitis_ai/test_vitis_ai_codegen.py | 25 +++++--
 tests/python/relay/test_op_level2.py               | 60 +++++++++-------
 tests/python/relay/test_op_qnn_conv2_transpose.py  | 53 +++++++-------
 tests/python/relay/test_pass_convert_op_layout.py  | 16 ++---
 17 files changed, 223 insertions(+), 155 deletions(-)

diff --git a/python/tvm/relay/frontend/caffe.py 
b/python/tvm/relay/frontend/caffe.py
index be76fee..30327e5 100644
--- a/python/tvm/relay/frontend/caffe.py
+++ b/python/tvm/relay/frontend/caffe.py
@@ -21,11 +21,12 @@
 import numpy as np
 import tvm
 from tvm.ir import IRModule
+
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
-from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
 
@@ -514,6 +515,9 @@ class OperatorConverter(object):
             weight_shape = [-1, conv_params.num_output, kh, kw]
             weight_value = np.asarray(weight.data, np.float32)
             weight_value = np.reshape(weight_value, weight_shape)
+
+            # weight shape is in relay's IOHW format rn, we need it to be OIHW
+            weight_value = np.transpose(weight_value, [1, 0, 2, 3])
         else:
             raise Exception("No weight value of layer {} in 
caffemodel".format(op.name))
 
@@ -521,7 +525,6 @@ class OperatorConverter(object):
         in_expr = self.exp_tab.get_expr(inputs[0])
         out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, 
**params)
         if bias:
-
             bias_value = np.asarray(bias.data, np.float32)
             bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
             out = _op.nn.bias_add(out, bias_expr)
diff --git a/python/tvm/relay/frontend/keras.py 
b/python/tvm/relay/frontend/keras.py
index 06a8dd3..901eee3 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -355,11 +355,14 @@ def _convert_convolution(inexpr, keras_layer, etab):
         else:
             kernel_layout = "HWIO"
     else:
-        kernel_layout = "OIHW"
+        if is_deconv:
+            kernel_layout = "IOHW"
+        else:
+            kernel_layout = "OIHW"
 
     if is_deconv:
         kernel_h, kernel_w, n_filters, in_channels = weight.shape
-        if kernel_layout == "OIHW":
+        if kernel_layout == "IOHW":
             weight = weight.transpose([3, 2, 0, 1])
     elif is_depthconv:
         kernel_h, kernel_w, in_channels, depth_mult = weight.shape
diff --git a/python/tvm/relay/frontend/mxnet.py 
b/python/tvm/relay/frontend/mxnet.py
index 59b4e99..1b1d601 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -18,40 +18,50 @@
 """MXNet symbol frontend."""
 import json
 import math
+
 import numpy as np
 import tvm
-from tvm.ir import IRModule
-
 from tvm import relay
+from tvm.ir import IRModule
 from tvm.topi.utils import get_const_tuple
+
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
 from .. import scope_builder as _scope_builder
-from ... import nd as _nd
-
 from .common import StrAttrsDict
-from .common import infer_type as _infer_type
+from .common import get_name as _get_name
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import infer_value as _infer_value
-from .common import get_name as _get_name
-from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
-from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
-from .nnvm_common import _clip, _transpose, _upsampling
-from .nnvm_common import _elemwise_sum, _reshape
-from .nnvm_common import _warn_not_used
 from .mxnet_qnn_op_utils import (
-    quantize_mxnet_min_max,
-    quantize_conv_weights_bias_channel_mkldnn_from_var,
-    quantize_conv_bias_mkldnn_from_var,
-    get_conv_mkldnn_requantized_scale_outDtype,
     dequantize_mxnet_min_max,
+    get_conv_mkldnn_requantized_scale_outDtype,
     get_mkldnn_int8_scale,
-    get_mkldnn_uint8_scale,
     get_mkldnn_requantize_scale_outDtype,
+    get_mkldnn_uint8_scale,
+    quantize_conv_bias_mkldnn_from_var,
+    quantize_conv_weights_bias_channel_mkldnn_from_var,
+    quantize_mxnet_min_max,
+)
+from .nnvm_common import (
+    _arg_reduce,
+    _binop_scalar,
+    _cast,
+    _clip,
+    _elemwise_sum,
+    _init_op,
+    _rbinop_scalar,
+    _reduce,
+    _rename,
+    _reshape,
+    _softmax_op,
+    _transpose,
+    _upsampling,
+    _warn_not_used,
 )
-
 
 __all__ = ["from_mxnet"]
 
@@ -329,7 +339,7 @@ def _mx_conv2d_transpose(inputs, attrs):
     if "kernel_layout" in attrs.attrs:
         kernel_layout = attrs.get_str("kernel_layout")
     else:
-        kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
+        kernel_layout = "HWIO" if data_layout == "NHWC" else "IOHW"
 
     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")
diff --git a/python/tvm/relay/frontend/pytorch.py 
b/python/tvm/relay/frontend/pytorch.py
index 13704ff..a17a10e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -19,8 +19,8 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, 
cell-var-from-loop, unnecessary-lambda
 # pylint: disable=missing-function-docstring
 """PT: PyTorch frontend."""
-import itertools
 import functools
+import itertools
 import logging
 import math
 import sys
@@ -40,11 +40,11 @@ from ..loops import while_loop
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell
-from .common import infer_value as _infer_value
+from .common import AttrCvt, get_relay_op, gru_cell
 from .common import infer_shape as _infer_shape
+from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
-from .common import try_infer_value
+from .common import lstm_cell, try_infer_value, unbind
 from .pytorch_utils import is_version_greater_than
 
 __all__ = ["from_pytorch"]
@@ -1010,6 +1010,9 @@ class PyTorchOpConverter:
         elif len(kernel_size) == 2:
             data_layout = "NCHW"
             kernel_layout = "OIHW"
+            if use_transpose:
+                # Transposed convolutions have IOHW layout.
+                kernel_layout = "IOHW"
         else:
             data_layout = "NCW"
             kernel_layout = "OIW"
diff --git a/python/tvm/relay/frontend/qnn_torch.py 
b/python/tvm/relay/frontend/qnn_torch.py
index 172ab1e..5772313 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -19,7 +19,6 @@
 import logging
 
 import numpy as np
-
 import tvm
 from tvm import relay
 from tvm.relay import expr as _expr
@@ -1043,11 +1042,8 @@ def _quantized_conv_transpose2d(with_relu=False):
 
         weight_shape = list(infer_shape(weight))
 
-        # Swap I and O dims to match shape relay expects for OIHW
-        weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
-
         kernel_size = (weight_shape[2], weight_shape[3])
-        out_channels = weight_shape[0]
+        out_channels = weight_shape[1]
 
         conv_out = relay.qnn.op.conv2d_transpose(
             inputs[0],
@@ -1064,7 +1060,7 @@ def _quantized_conv_transpose2d(with_relu=False):
             channels=out_channels,
             output_padding=output_padding,
             out_dtype="int32",
-            kernel_layout="OIHW",
+            kernel_layout="IOHW",
         )
 
         return _do_bias_and_requantize(
diff --git a/python/tvm/relay/frontend/tensorflow_ops.py 
b/python/tvm/relay/frontend/tensorflow_ops.py
index a8213d4..26ea4f4 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -461,8 +461,11 @@ def _conv(opname):
             raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))
 
         if "kernel_layout" not in attr:
-            if opname in ["conv", "conv_transpose"]:
+            if opname == "conv":
                 attr["kernel_layout"] = "HWIO" if attr["data_format"] == 
"NHWC" else "OIHW"
+            elif opname == "conv_transpose":
+                # conv_transpose in TVM has weights be IOHW for NCHW
+                attr["kernel_layout"] = "HWIO" if attr["data_format"] == 
"NHWC" else "IOHW"
             else:
                 attr["kernel_layout"] = "HWOI" if attr["data_format"] == 
"NHWC" else "OIHW"
 
diff --git a/python/tvm/relay/frontend/tflite.py 
b/python/tvm/relay/frontend/tflite.py
index 5b32717..05b3041 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -16,24 +16,25 @@
 # under the License.
 # pylint: disable=invalid-name, unused-argument, too-many-lines, 
import-outside-toplevel
 """Tensorflow lite frontend."""
-import math
 import itertools
+import math
+
 import numpy as np
 import tvm
+from tvm import relay
 from tvm.ir import IRModule
 
-from tvm import relay
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
-from ... import nd as _nd
 from .common import ExprTable
-from .common import infer_shape as _infer_shape, to_int_list
+from .common import infer_shape as _infer_shape
+from .common import to_int_list
 from .tflite_flexbuffer import FlexBufferDecoder
 
-
 __all__ = ["from_tflite"]
 
 
@@ -53,9 +54,9 @@ class OperatorConverter(object):
     def __init__(self, model, subgraph, exp_tab):
 
         try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
             from tflite.BuiltinOperator import BuiltinOperator
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -1061,8 +1062,8 @@ class OperatorConverter(object):
     def convert_concatenation(self, op):
         """Convert TFLite concatenation"""
         try:
-            from tflite.ConcatenationOptions import ConcatenationOptions
             from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ConcatenationOptions import ConcatenationOptions
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -1242,10 +1243,10 @@ class OperatorConverter(object):
         """Generic method to Convert TFLite elemwise"""
         try:
             from tflite.AddOptions import AddOptions
-            from tflite.SubOptions import SubOptions
-            from tflite.MulOptions import MulOptions
-            from tflite.DivOptions import DivOptions
             from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.DivOptions import DivOptions
+            from tflite.MulOptions import MulOptions
+            from tflite.SubOptions import SubOptions
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -1804,9 +1805,9 @@ class OperatorConverter(object):
     def _convert_arg_min_max(self, relay_op, op):
         """Generic method converting TFLite arg_min_max"""
         try:
-            from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ArgMinOptions import ArgMinOptions
             from tflite.ArgMaxOptions import ArgMaxOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.BuiltinOptions import BuiltinOptions
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -1853,8 +1854,8 @@ class OperatorConverter(object):
     def convert_fully_connected(self, op):
         """Convert TFLite fully connected"""
         try:
-            from tflite.FullyConnectedOptions import FullyConnectedOptions
             from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.FullyConnectedOptions import FullyConnectedOptions
             from tflite.TensorType import TensorType
         except ImportError:
             raise ImportError("The tflite package must be installed")
@@ -2024,10 +2025,10 @@ class OperatorConverter(object):
         """convolution implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.TensorType import TensorType
             from tflite.Conv2DOptions import Conv2DOptions
             from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
             from tflite.Padding import Padding
+            from tflite.TensorType import TensorType
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -2434,8 +2435,8 @@ class OperatorConverter(object):
         """pool2d implementation."""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.Pool2DOptions import Pool2DOptions
             from tflite.Padding import Padding
+            from tflite.Pool2DOptions import Pool2DOptions
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -2850,9 +2851,9 @@ class OperatorConverter(object):
         """Convert TFLite TRANSPOSE_CONV"""
         try:
             from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.Padding import Padding
             from tflite.TensorType import TensorType
             from tflite.TransposeConvOptions import TransposeConvOptions
-            from tflite.Padding import Padding
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
@@ -2946,7 +2947,7 @@ class OperatorConverter(object):
                 channels=int(out_channels),
                 kernel_size=(int(kernel_h), int(kernel_w)),
                 data_layout="NHWC",
-                kernel_layout="OIHW",
+                kernel_layout="IOHW",
                 out_dtype="int32",
             )
         else:
@@ -2958,7 +2959,7 @@ class OperatorConverter(object):
                 channels=int(out_channels),
                 kernel_size=(int(kernel_h), int(kernel_w)),
                 data_layout="NHWC",
-                kernel_layout="OIHW",
+                kernel_layout="IOHW",
                 out_dtype=output_tensor_type_str,
             )
 
@@ -3717,8 +3718,8 @@ def from_tflite(model, shape_dict=None, dtype_dict=None, 
op_converter=OperatorCo
         The parameter dict to be used by relay
     """
     try:
-        import tflite.SubGraph
         import tflite.BuiltinOperator
+        import tflite.SubGraph
     except ImportError:
         raise ImportError("The tflite package must be installed")
 
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 1821ff1..c7b376e 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -522,7 +522,7 @@ def conv2d_transpose(
     channels=None,
     kernel_size=None,
     data_layout="NCHW",
-    kernel_layout="OIHW",
+    kernel_layout="IOHW",
     out_layout="",
     output_padding=(0, 0),
     out_dtype="",
diff --git a/python/tvm/relay/qnn/op/layout_conversions.py 
b/python/tvm/relay/qnn/op/layout_conversions.py
index 1a3b177..24c787e 100644
--- a/python/tvm/relay/qnn/op/layout_conversions.py
+++ b/python/tvm/relay/qnn/op/layout_conversions.py
@@ -119,7 +119,7 @@ def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, 
desired_layouts):
 
     # Handle default kernel layouts
     if desired_data_layout == "NCHW":
-        new_attrs["kernel_layout"] = "OIHW"
+        new_attrs["kernel_layout"] = "IOHW"
         return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
     if desired_data_layout == "NHWC":
         new_attrs["kernel_layout"] = "HWIO"
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 83b5cf0..4749b63 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -18,13 +18,15 @@
 """QNN dialect operators."""
 
 from __future__ import absolute_import as _abs
+
 from tvm import relay
 from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.utils import get_pad_tuple2d
 from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
-from . import _make
+
 from ... import op as reg
 from ...op import OpPattern
+from . import _make
 
 
 def requantize(
@@ -382,10 +384,10 @@ def conv2d_transpose(
     channels=None,
     kernel_size=None,
     data_layout="NCHW",
-    kernel_layout="OIHW",
+    kernel_layout="IOHW",
     out_layout="",
     output_padding=(0, 0),
-    out_dtype="",
+    out_dtype="int32",
 ):
     """This operator deconvolves quantized data with quantized kernel. The 
scale of
     the output quantized tensor is the product of the kernel_scale and
diff --git a/python/tvm/relay/testing/dcgan.py 
b/python/tvm/relay/testing/dcgan.py
index fc531b7..acc4783 100644
--- a/python/tvm/relay/testing/dcgan.py
+++ b/python/tvm/relay/testing/dcgan.py
@@ -27,6 +27,7 @@ Radford, Alec, Luke Metz, and Soumith Chintala.
 arXiv preprint arXiv:1511.06434 (2015).
 """
 from tvm import relay
+
 from . import layers
 from .init import create_workload
 
@@ -41,7 +42,7 @@ def deconv2d(data, ishape, oshape, kshape, layout, name, 
stride=(2, 2)):
     adj_x = (target_shape[1] + 2 * pad_x - kshape[1]) % stride[1]
 
     if layout == "NCHW":
-        kernel_layout = "OIHW"
+        kernel_layout = "IOHW"
     elif layout == "NHWC":
         kernel_layout = "HWOI"
     else:
diff --git a/python/tvm/topi/nn/conv2d_transpose.py 
b/python/tvm/topi/nn/conv2d_transpose.py
index 22188bc..99c7442 100644
--- a/python/tvm/topi/nn/conv2d_transpose.py
+++ b/python/tvm/topi/nn/conv2d_transpose.py
@@ -17,12 +17,12 @@
 # pylint: disable=invalid-name, unused-variable, unused-argument
 """Transposed 2D convolution operators (sometimes called Deconvolution)."""
 import tvm
-from tvm import te
-from tvm import relay
+from tvm import relay, te
+
+from ..utils import simplify
 from .dilate import dilate
 from .pad import pad
 from .utils import get_pad_tuple
-from ..utils import simplify
 
 
 def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype, 
output_padding):
@@ -116,6 +116,41 @@ def declaration_conv2d_transpose_impl(data, kernel, 
strides, padding, out_dtype,
     return Output
 
 
+def layout_transform(tensor: "relay.Expr", current_layout: str, 
desired_layout: str):
+    """Transform a tensor with the current layout to the desired layout.
+
+    E.g. layout_transform(t, "NCHW", "CNHW") --> relay.transpose(t, [1, 0, 2, 
3])
+
+    Parameters
+    ----------
+    tensor: relay.Expr
+        The Tensor to transpose
+
+    current_layout: str
+        The current layout e.g. NCHW or OIHW
+
+    desired_layout: str
+        The desired layout, must be compatible with current_layout
+
+    Returns
+    -------
+    The layout_transformed tensor.
+    """
+    if sorted(current_layout) != sorted(desired_layout):
+        raise ValueError(f"Incompatible layouts: {current_layout} vs 
{desired_layout}")
+
+    if current_layout == desired_layout:
+        return tensor
+
+    current_layout_map = {c: i for i, c in enumerate(current_layout)}
+    desired_layout_map = {c: i for i, c in enumerate(desired_layout)}
+
+    axes = [None] * len(current_layout)
+    for c, i in desired_layout_map.items():
+        axes[i] = current_layout_map[c]
+    return relay.transpose(tensor, axes=axes)
+
+
 @tvm.target.generic_func
 def conv2d_transpose_legalize(attrs, inputs, types):
     """Legalizes Transposed 2D convolution op.
@@ -134,36 +169,17 @@ def conv2d_transpose_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
+
+    data, kernel = inputs
+    kernel_layout = attrs["kernel_layout"]
     if attrs["data_layout"] == "NHWC":
-        data, kernel = inputs
-        kernel_layout = attrs["kernel_layout"]
-        # Convert Kernel layout to IOHW
-        # kernel_layout is different from input kernel layout - IO is swapped
-        if kernel_layout == "HWIO":
-            # input kernel layout is swapped to HWOI
-            # output kernel layout will be IOHW
-            kernel = relay.transpose(kernel, axes=(3, 2, 0, 1))
-        elif kernel_layout == "HWOI":
-            # input kernel layout is swapped to HWIO
-            # output kernel layout will be IOHW
-            kernel = relay.transpose(kernel, axes=(2, 3, 0, 1))
-        elif kernel_layout == "IOHW":
-            # input kernel layout is swapped to OIHW
-            # output kernel layout will be IOHW
-            kernel = relay.transpose(kernel, axes=(1, 0, 2, 3))
-        elif kernel_layout == "OIHW":
-            # input kernel layout is swapped to IOHW
-            # output kernel layout will be IOHW
-            pass
-        else:
-            # Skip legalize. Let relay.nn.conv2d_transpose to handle the case
-            return None
+        kernel = layout_transform(kernel, kernel_layout, "IOHW")
 
         # Set new attrs for conv2d_transpose.
         new_attrs = {k: attrs[k] for k in attrs.keys()}
         new_attrs["data_layout"] = "NCHW"
-        # layout of kernel should be IOHW, but kernel_layout should be swapped 
- OIHW
-        new_attrs["kernel_layout"] = "OIHW"
+        # layout of kernel should be IOHW, but kernel_layout will be swapped - 
OIHW
+        new_attrs["kernel_layout"] = "IOHW"
 
         # Convert data to NCHW.
         data = relay.transpose(data, axes=(0, 3, 1, 2))
@@ -172,4 +188,12 @@ def conv2d_transpose_legalize(attrs, inputs, types):
         out = relay.transpose(deconv, axes=(0, 2, 3, 1))
         return out
 
+    if attrs["data_layout"] == "NCHW":
+        kernel = layout_transform(kernel, kernel_layout, "IOHW")
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+        # layout of kernel should be IOHW, but kernel_layout will be swapped - 
OIHW
+        new_attrs["kernel_layout"] = "IOHW"
+        return relay.nn.conv2d_transpose(data, kernel, **new_attrs)
+
     return None
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index c27227b..d995807 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -1044,7 +1044,7 @@ bool Conv2DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
   if (data == nullptr) return false;
 
   static const Layout kNCHW("NCHW");
-  static const Layout kOIHW("OIHW");
+  static const Layout kIOHW("IOHW");
 
   const Conv2DTransposeAttrs* param = attrs.as<AttrType>();
   ICHECK(param != nullptr);
@@ -1056,9 +1056,9 @@ bool Conv2DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
       << "Conv only support input layouts that are convertible from NCHW."
       << " But got " << in_layout;
 
-  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW);
   ICHECK(trans_kernel_layout.defined())
-      << "Conv only support kernel layouts that are convertible from OIHW."
+      << "Conv only support kernel layouts that are convertible from IOHW."
       << " But got " << kernel_layout;
 
   Layout out_layout(param->out_layout == "" ? param->data_layout : 
param->out_layout);
diff --git a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py 
b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py
index c5c9cc7..b89cc37 100644
--- a/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py
+++ b/tests/python/contrib/test_vitis_ai/test_vitis_ai_codegen.py
@@ -19,24 +19,23 @@
 """Vitis-AI codegen tests"""
 
 import sys
-import numpy as np
 
+import numpy as np
 import pytest
 
 pytest.importorskip("pyxir")
 import pyxir.contrib.target.DPUCADF8H
 import pyxir.contrib.target.DPUCAHX8H
 import pyxir.contrib.target.DPUCAHX8L
-import pyxir.contrib.target.DPUCVDX8H
 import pyxir.contrib.target.DPUCVDX8G
+import pyxir.contrib.target.DPUCVDX8H
 import pyxir.contrib.target.DPUCZDX8G
-
 import tvm
 from tvm import relay
+from tvm.contrib.target import vitis_ai
 from tvm.relay import transform
-from tvm.relay.op.contrib.vitis_ai import annotation
 from tvm.relay.build_module import bind_params_by_name
-from tvm.contrib.target import vitis_ai
+from tvm.relay.op.contrib.vitis_ai import annotation
 
 from .infrastructure import skip_test, verify_codegen
 
@@ -241,6 +240,13 @@ def test_upsampling(dpu_target):
     verify_codegen(mod, dpu_target=dpu_target)
 
 
[email protected](
+    reason="I and O used to be mixed up in kernel layouts in TVM."
+    "This is fixed, but vitis needs to adopt the new convention."
+    "To change, simply remove this line:"
+    
"https://github.com/Xilinx/pyxir/blob/bef661d6d77adcdbd2cf4163f2cf3a1d31d40406/";
+    "python/pyxir/frontend/tvm/relay_tools/relay_l2_convolution.py#L380"
+)
 @pytest.mark.parametrize(
     "dpu_target",
     ["DPUCADF8H", "DPUCAHX8H-u50", "DPUCAHX8L", "DPUCVDX8H", "DPUCVDX8G", 
"DPUCZDX8G-zcu104"],
@@ -253,7 +259,14 @@ def test_conv2d_transpose(dpu_target):
     x = relay.var("x", shape=dshape)
     w = relay.const(np.zeros(kshape, dtype="float32"))
     y = relay.nn.conv2d_transpose(
-        x, w, channels=10, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1)
+        x,
+        w,
+        channels=10,
+        kernel_size=(3, 3),
+        strides=(1, 1),
+        padding=(1, 1),
+        data_layout="NCHW",
+        kernel_layout="IOHW",
     )
     func = relay.Function([x], y)
     params = {}
diff --git a/tests/python/relay/test_op_level2.py 
b/tests/python/relay/test_op_level2.py
index da28770..db712be 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -20,13 +20,12 @@ import sys
 
 import numpy as np
 import pytest
-
 import tvm
 import tvm.testing
 import tvm.topi.testing
-
 from tvm import autotvm, relay, te
 from tvm.contrib import utils
+from tvm.ir.module import IRModule
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
 from tvm.topi.cuda.conv3d_winograd import _infer_tile_size
@@ -838,25 +837,42 @@ def test_conv2d_transpose_infer_type():
 
 @tvm.testing.uses_gpu
 def test_conv2d_transpose_nchw_run():
-    dshape = (1, 3, 18, 18)
-    kshape = (3, 10, 3, 3)
-    oshape = (1, 10, 36, 36)
-    x = relay.var("x", shape=dshape)
-    w = relay.var("w")
-    y = relay.nn.conv2d_transpose(
-        x, w, channels=10, kernel_size=(3, 3), strides=(2, 2), padding=(1, 1), 
output_padding=(1, 1)
-    )
-    func = relay.Function([x, w], y)
-    dtype = "float32"
-    data = np.random.uniform(size=dshape).astype(dtype)
-    kernel = np.random.uniform(size=kshape).astype(dtype)
-    ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, kernel, 2, 
1, (1, 1))
+    k_layouts = {"OIHW": (10, 3, 3, 3), "IOHW": (3, 10, 3, 3)}
 
-    for target, dev in tvm.testing.enabled_targets():
-        op_res1 = relay.create_executor("graph", device=dev, 
target=target).evaluate(func)(
-            data, kernel
+    for k_layout, kshape in k_layouts.items():
+        dshape = (1, 3, 18, 18)
+        oshape = (1, 10, 36, 36)
+        x = relay.var("x", shape=dshape)
+        w = relay.var("w")
+        y = relay.nn.conv2d_transpose(
+            x,
+            w,
+            channels=10,
+            kernel_size=(3, 3),
+            strides=(2, 2),
+            padding=(1, 1),
+            output_padding=(1, 1),
+            kernel_layout=k_layout,
+            data_layout="NCHW",
         )
-        tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, 
atol=1e-5)
+        func = relay.Function([x, w], y)
+        dtype = "float32"
+        data = np.random.uniform(size=dshape).astype(dtype)
+        kernel = np.random.uniform(size=kshape).astype(dtype)
+
+        if k_layout != "IOHW":
+            # Must be OIHW so switch
+            kernel_iohw = np.transpose(kernel, [1, 0, 2, 3])
+        else:
+            kernel_iohw = kernel
+
+        ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(data, 
kernel_iohw, 2, 1, (1, 1))
+
+        for target, dev in tvm.testing.enabled_targets():
+            op_res1 = relay.create_executor("graph", device=dev, 
target=target).evaluate(func)(
+                data, kernel
+            )
+            tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, 
atol=1e-5)
 
 
 @tvm.testing.uses_gpu
@@ -866,8 +882,7 @@ def test_conv2d_transpose_nhwc_run():
     oshape_nhwc = (1, 36, 36, 10)
     x = relay.var("x", shape=dshape_nhwc)
     w = relay.var("w")
-    # kshape and kernel_layout should have swapped IO.
-    # kshape is HWOI and kernel_layout is HWIO
+
     y = relay.nn.conv2d_transpose(
         x,
         w,
@@ -877,13 +892,12 @@ def test_conv2d_transpose_nhwc_run():
         padding=(1, 1),
         output_padding=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
     )
     func = relay.Function([x, w], y)
     dtype = "float32"
     data = np.random.uniform(size=dshape_nhwc).astype(dtype)
     kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
-    # use true kshape layout here - HWOI
 
     ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(
         data, kernel, "HWOI", 2, 1, output_padding=(1, 1)
diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py 
b/tests/python/relay/test_op_qnn_conv2_transpose.py
index 9fd3d1b..a7c3c91 100644
--- a/tests/python/relay/test_op_qnn_conv2_transpose.py
+++ b/tests/python/relay/test_op_qnn_conv2_transpose.py
@@ -15,13 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import tvm
-from tvm import te
 import numpy as np
-from tvm import relay
+import tvm
+from tvm import relay, te
+from tvm.contrib import graph_executor
 from tvm.relay import transform
 from tvm.relay.testing import run_infer_type
-from tvm.contrib import graph_executor
 from tvm.relay.testing.temp_op_attr import TempOpAttr
 
 
@@ -224,7 +223,7 @@ def test_no_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -248,7 +247,7 @@ def test_no_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -274,7 +273,7 @@ def test_kernel_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -298,7 +297,7 @@ def test_kernel_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -324,7 +323,7 @@ def test_input_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -348,7 +347,7 @@ def test_input_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -374,7 +373,7 @@ def test_both_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -398,7 +397,7 @@ def test_both_zero_point():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -424,7 +423,7 @@ def test_different_dtype():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
         channels=kernel_shape[1],
     )
@@ -449,7 +448,7 @@ def test_different_dtype():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
         channels=kernel_shape[1],
     )
@@ -460,7 +459,7 @@ def test_layout():
     # uint8 input
     data_shape = (2, 2, 4, 4)  # NHWC
     data_dtype = "uint8"
-    kernel_shape = (2, 2, 3, 4)  # HWIO
+    kernel_shape = (2, 2, 3, 4)  # HWOI
     kernel_dtype = "uint8"
     ref_func, qnn_func = get_funcs(
         data_shape=data_shape,
@@ -476,14 +475,14 @@ def test_layout():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
 
     data_shape = (2, 2, 4, 3)  # NHWC
     data_dtype = "uint8"
-    kernel_shape = (2, 2, 1, 3)  # HWIO
+    kernel_shape = (2, 2, 1, 3)  # HWOI
     kernel_dtype = "uint8"
     ref_func, qnn_func = get_funcs(
         data_shape=data_shape,
@@ -499,7 +498,7 @@ def test_layout():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -525,7 +524,7 @@ def test_padding():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -533,7 +532,7 @@ def test_padding():
     # Try different layout
     data_shape = (2, 2, 4, 4)  # NHWC
     data_dtype = "uint8"
-    kernel_shape = (2, 2, 3, 4)  # HWIO
+    kernel_shape = (2, 2, 3, 4)  # HWOI
     kernel_dtype = "uint8"
     ref_func, qnn_func = get_funcs(
         data_shape=data_shape,
@@ -549,7 +548,7 @@ def test_padding():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -557,7 +556,7 @@ def test_padding():
     # Try asymmetric padding
     data_shape = (2, 8, 6, 4)  # NHWC
     data_dtype = "uint8"
-    kernel_shape = (2, 2, 3, 4)  # HWIO
+    kernel_shape = (2, 2, 3, 4)  # HWOI
     kernel_dtype = "uint8"
     ref_func, qnn_func = get_funcs(
         data_shape=data_shape,
@@ -573,7 +572,7 @@ def test_padding():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
         out_dtype="int32",
     )
     verify(ref_func, qnn_func, data_shape, data_dtype, kernel_shape, 
kernel_dtype)
@@ -600,7 +599,7 @@ def test_const_folding():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
         channels=kernel_shape[1],
         groups=1,
@@ -614,7 +613,7 @@ def test_broadcast_layout():
     # Test broadcast support for NHWC layout.
     data_shape = (1, 229, 229, 3)  # NHWC
     data_dtype = "uint8"
-    kernel_shape = (7, 7, 64, 3)  # HWIO
+    kernel_shape = (7, 7, 64, 3)  # HWOI
     kernel_dtype = "int8"
     _, qnn_func = get_funcs(
         data_shape=data_shape,
@@ -630,7 +629,7 @@ def test_broadcast_layout():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NHWC",
-        kernel_layout="HWIO",
+        kernel_layout="HWOI",
         out_dtype="int32",
     )
     func = qnn_func["main"].body
@@ -670,7 +669,7 @@ def test_per_channel_kernel_scale():
         strides=(1, 1),
         dilation=(1, 1),
         data_layout="NCHW",
-        kernel_layout="OIHW",
+        kernel_layout="IOHW",
         out_dtype="int32",
     )
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py 
b/tests/python/relay/test_pass_convert_op_layout.py
index 2359dcd..5360018 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -16,15 +16,12 @@
 # under the License.
 """Test alter op layout pass"""
 import pytest
-
 import tvm
-from tvm import te
-
-from tvm import relay
+from tvm import relay, te
+from tvm.relay import analysis, transform
+from tvm.relay.op import op as reg
 from tvm.relay.op import register_alter_op_layout
-from tvm.relay import transform, analysis
 from tvm.relay.transform.infer_layout_utils import InferCorrectLayoutOutput
-from tvm.relay.op import op as reg
 
 
 def run_opt_pass(expr, passes):
@@ -182,7 +179,7 @@ def test_conv_transpose_convert_layout():
         x = relay.var("x", shape=(1, 56, 56, 64))
         weight = relay.var("weight", shape=(3, 3, 64, 64))
         x = relay.layout_transform(x, "NHWC", "NCHW")
-        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        weight = relay.layout_transform(weight, "HWIO", "IOHW")
         y = relay.nn.conv2d_transpose(x, weight, channels=64, kernel_size=(3, 
3), padding=(1, 1))
         y = relay.nn.relu(y)
         y = relay.layout_transform(y, "NCHW", "NHWC")
@@ -190,7 +187,7 @@ def test_conv_transpose_convert_layout():
         return y
 
     a = before()
-    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d_transpose": 
["NCHW", "OIHW"]}))
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d_transpose": 
["NCHW", "IOHW"]}))
     b = run_opt_pass(expected(), transform.InferType())
 
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
@@ -1373,7 +1370,7 @@ def test_qnn_conv_transpose_requantize_convert_layout():
         x = relay.var("x", shape=(1, 56, 56, 64), dtype="int8")
         weight = relay.var("weight", shape=(3, 3, 64, 64), dtype="int8")
         x = relay.layout_transform(x, "NHWC", "NCHW")
-        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        weight = relay.layout_transform(weight, "HWIO", "IOHW")
         y = relay.qnn.op.conv2d_transpose(
             x,
             weight,
@@ -1403,7 +1400,6 @@ def test_qnn_conv_transpose_requantize_convert_layout():
     a = before()
     a = run_opt_pass(a, transform.ConvertLayout({"qnn.conv2d_transpose": 
["NCHW", "default"]}))
     b = run_opt_pass(expected(), transform.InferType())
-
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 

Reply via email to