This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0fc5fd5118 [Unity][Op] Add repeat, tile, conv2d_transpose, avg_pool2d 
(#14238)
0fc5fd5118 is described below

commit 0fc5fd5118cb22db9b747d8cf3bf6dd17ed52c3b
Author: Yixin Dong <[email protected]>
AuthorDate: Thu Mar 9 21:21:38 2023 +0800

    [Unity][Op] Add repeat, tile, conv2d_transpose, avg_pool2d (#14238)
    
    This PR adds a series of operators for relax:
    - `repeat(data: Expr, repeats: int, axis: Optional[int] = None)`
    - `tile(data: Expr, repeats: Union[int, Tuple[int], List[int]])`
    - `relax.nn.conv2d_transpose`
            - This operator is intended to find the gradient of conv2d w.r.t 
its input. For details see [pytorch 
document](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html),
 and the document in `python/tvm/relax/op/nn/nn.py`.
            - TOPI support for conv2d_transpose is currently limited: it does not 
support dilations, non-default layouts, or symbolic `output_padding`.
    - `relax.nn.avg_pool2d`
---
 include/tvm/relax/attrs/manipulate.h               |  23 ++
 include/tvm/relax/attrs/nn.h                       |  55 ++-
 python/tvm/relax/op/manipulate.py                  |  74 ++++
 python/tvm/relax/op/nn/nn.py                       | 182 +++++++++-
 python/tvm/relax/op/op_attrs.py                    |  19 +-
 python/tvm/relax/transform/legalize_ops/common.py  |   2 +-
 python/tvm/relax/transform/legalize_ops/index.py   |   6 +-
 .../tvm/relax/transform/legalize_ops/manipulate.py |  25 ++
 python/tvm/relax/transform/legalize_ops/nn.py      |  58 ++++
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 src/relax/op/nn/convolution.cc                     | 144 ++++++++
 src/relax/op/nn/convolution.h                      |  11 +
 src/relax/op/nn/pooling.cc                         |  44 ++-
 src/relax/op/nn/pooling.h                          |   4 +
 src/relax/op/tensor/manipulate.cc                  | 126 +++++++
 src/relax/op/tensor/manipulate.h                   |  27 ++
 tests/python/relax/test_op_manipulate.py           | 246 +++++++++++++
 tests/python/relax/test_op_nn_convolution.py       | 385 +++++++++++++++++++++
 tests/python/relax/test_op_nn_pooling.py           | 228 +++++++++++-
 .../test_transform_legalize_ops_manipulate.py      | 199 ++++++++++-
 .../python/relax/test_transform_legalize_ops_nn.py | 347 ++++++++++++++++++-
 .../relax/test_tvmscript_parser_op_manipulate.py   |  45 +++
 tests/python/relax/test_tvmscript_parser_op_nn.py  |  37 ++
 23 files changed, 2264 insertions(+), 27 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h 
b/include/tvm/relax/attrs/manipulate.h
index bd6ae17bcf..4982daf7e4 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -102,6 +102,29 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   }
 };  // struct SqueezeAttrs
 
+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+  int repeats;
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs") {
+    TVM_ATTR_FIELD(repeats).describe("The number of repetitions.");
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis along which to repeat values. The negative numbers are 
interpreted "
+        "counting from the backward. By default, use the flattened input 
array, and "
+        "return a flat output array.");
+  }
+};  // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+  Array<Integer> repeats;
+
+  TVM_DECLARE_ATTRS(TileAttrs, "relax.attrs.TileAttrs") {
+    TVM_ATTR_FIELD(repeats).describe("The number of repetitions of data along 
each axis.");
+  }
+};  // struct TileAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 61b1622a60..f49cb6b121 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -74,8 +74,55 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   }
 };  // struct Conv2dAttrs
 
-/*! \brief Attributes used in max_pool2d operator */
-struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
+/*! \brief Attributes used in conv2d_transpose operator */
+struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> output_padding;
+  Array<IntImm> dilation;
+  int groups;
+  String data_layout;
+  String kernel_layout;
+  String out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relax.attrs.Conv2DTransposeAttrs") {
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the 
convolution.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded"
+        "Padding support both symmetric and asymmetric as"
+        "one int : same padding used on all sides"
+        "two int : bottom, right will use same padding as top, left"
+        "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(output_padding).describe("Used to disambiguate the output 
shape.");
+    TVM_ATTR_FIELD(dilation).describe(
+        "Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).describe(
+        "Number of groups to split the input into for grouped convolution. The 
number of input and "
+        "output channels should be divisible by the number of groups.");
+    TVM_ATTR_FIELD(data_layout)
+        .describe(
+            "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Convolution is applied on the 'H' and"
+            "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout)
+        .describe(
+            "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+            "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, 
and width"
+            "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype).describe(
+        "Output data type, set to explicit type under mixed precision 
setting");
+  }
+};  // struct Conv2DTransposeAttrs
+
+/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
+struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
   Array<IntImm> pool_size;
   Array<IntImm> strides;
   Array<IntImm> padding;
@@ -84,7 +131,7 @@ struct MaxPool2DAttrs : public 
tvm::AttrsNode<MaxPool2DAttrs> {
   String layout;
   String out_layout;
 
-  TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") {
+  TVM_DECLARE_ATTRS(Pool2DAttrs, "relax.attrs.Pool2DAttrs") {
     TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
     TVM_ATTR_FIELD(strides).describe("Specifies the strides of the 
convolution.");
     TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the 
convolution.");
@@ -109,7 +156,7 @@ struct MaxPool2DAttrs : public 
tvm::AttrsNode<MaxPool2DAttrs> {
             "dimensions respectively. Pooling is applied on the 'H' and"
             "'W' dimensions.");
   }
-};  // struct MaxPool2dAttrs
+};  // struct Pool2dAttrs
 
 /*! \brief Attributes for 2d adaptive pool operator */
 struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index 25bf525191..c59ab793e0 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -314,3 +314,77 @@ def collapse_sum_to(data: Expr, shape: 
Union[Tuple[PrimExprLike], Expr]) -> Expr
     if isinstance(shape, (tuple, list)):
         shape = ShapeExpr(shape)
     return _ffi_api.collapse_sum_to(data, shape)  # type: ignore
+
+
+def repeat(data: Expr, repeats: int, axis: Optional[int] = None) -> Expr:
+    """Repeats elements of an array.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    repeats : int
+        The number of repetitions.
+
+    axis: Optional[int]
+        The axis along which to repeat values. Negative numbers are 
interpreted
+        as counting from the back. By default, use the flattened input array, 
and
+        return a flat output array.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+        x = R.const([[1, 2], [3, 4]])
+        lv1 = R.repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
+        lv2 = R.repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.],
+                                             #         [3., 3., 4., 4.]]
+    """
+    return _ffi_api.repeat(data, repeats, axis)  # type: ignore
+
+
+def tile(data: Expr, repeats: Union[int, Tuple[int], List[int]]) -> Expr:
+    """Construct an array by repeating data the number of times given by 
repeats.
+
+    If repeats has length l, and data has dimension d, the result will have 
dimension of max(l, d).
+
+    If d < l, data is promoted to be l-dimensional by prepending new axes. So 
a shape (3,) Tensor is
+    promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D 
replication. If this is not
+    the desired behavior, promote data to d-dimensions manually before calling 
this function.
+
+    If d > l, repeats is promoted to length d by pre-pending 1's to it. Thus for 
a data of shape
+    (2, 3, 4, 5), a repeats of (2, 2) is treated as (1, 1, 2, 2).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    repeats : Union[int, Tuple[int], List[int]]
+        The number of repetitions of data along each axis.
+
+    Returns
+    -------
+    ret : relax.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = R.const([[1, 2], [3, 4]])
+        lv1 = R.tile(x, repeats=(2, 3)) # lv1 = [[1., 2., 1., 2., 1., 2.],
+                                        #        [3., 4., 3., 4., 3., 4.],
+                                        #        [1., 2., 1., 2., 1., 2.],
+                                        #        [3., 4., 3., 4., 3., 4.]]
+        lv2 = R.tile(x, repeats=2) # lv2 = [[1., 2., 1., 2.],
+                                   #        [3., 4., 3., 4.]]
+    """
+    if isinstance(repeats, int):
+        repeats = [repeats]
+    return _ffi_api.tile(data, repeats)  # type: ignore
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index bbb1268f1c..c774bbc926 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -121,6 +121,108 @@ def conv2d(
     )
 
 
+def conv2d_transpose(
+    data: Expr,
+    weight: Expr,
+    strides: Union[int, Tuple[int, int]] = (1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    output_padding: Union[int, Tuple[int, int]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1),
+    groups: int = 1,
+    data_layout: str = "NCHW",
+    kernel_layout: str = "IOHW",
+    out_layout: Optional[str] = None,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    r"""Two dimensional transposed convolution operator.
+
+    This operator is intended to be the gradient operator of conv2d. That 
means, if
+
+    `out = conv2d(data, weight, strides, padding, dilation)`,
+
+    The gradient w.r.t. data can be calculated as follows:
+
+    `data_grad = conv2d_transpose(out_grad, weight, strides, padding, 
output_padding, dilation)`,
+
+    where `output_padding` is a parameter used to determine the output shape.
+
+    The output shape can be explained in the simple case when `data_layout == 
"NCHW"` and
+    `kernel_layout == "IOHW"`. Suppose `data` has shape `(N, in_channel, in_h, 
in_w)`, `weight` has
+    shape `(in_channel, out_channel, weight_h, weight_w)`, we need to assure 
that
+    `in_channel % groups == 0`. The shape of the output will be
+    `(N, out_channel * groups, out_h, out_w)`, where
+
+    - `out_h = ((in_h - 1) * strides[0] + weight_h - 2 * padding[0] + 
output_padding[0])`
+    - `out_w = ((in_w - 1) * strides[1] + weight_w - 2 * padding[1] + 
output_padding[1])`
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    weight : relax.Expr
+        The weight expressions.
+
+    strides : Union[int, Tuple[int, int]]
+        The strides of convolution. It is required to have length either 1 or 
2.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding of convolution on both sides of inputs before convolution.
+        It is required to have length either 1, 2 or 4.
+
+    output_padding : Union[int, Tuple[int, int]], optional
+        Used to disambiguate the output shape.
+
+    dilation : Union[int, Tuple[int, int]]
+        Specifies the dilation rate to be used for dilated convolution.
+        It is required to have length either 1 or 2.
+
+    groups : int
+        Number of groups to split the input into for grouped convolution.
+        The number of input and output channels should be divisible by the 
number of groups.
+
+    data_layout : str
+        Layout of the input.
+
+    kernel_layout : str
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    # TODO: symbolic shape is not fully supported now
+    if isinstance(strides, int):
+        strides = (strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding)
+    if isinstance(output_padding, int):
+        output_padding = (output_padding, output_padding)
+
+    return _ffi_api.conv2d_transpose(  # type: ignore
+        data,
+        weight,
+        strides,
+        padding,
+        output_padding,
+        dilation,
+        groups,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
 def max_pool2d(
     data: Expr,
     pool_size: Union[int, Tuple[int, int]] = (1, 1),
@@ -134,8 +236,7 @@ def max_pool2d(
     r"""2D maximum pooling operator.
 
     This operator takes data as input and does 2D max value calculation
-    with in pool_size sized window by striding defined by stride
-
+    within a pool_size sized window by striding defined by stride.
 
     In the default case, where the data_layout is `NCHW`
     a data Tensor with shape `(batch_size, in_channels, height, width)`,
@@ -198,6 +299,83 @@ def max_pool2d(
     )
 
 
+def avg_pool2d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1, 1),
+    strides: Union[int, Tuple[int, int]] = (1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1),
+    ceil_mode: bool = False,
+    layout: str = "NCHW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""2D average pooling operator.
+
+    This operator takes data as input and does 2D average value calculation
+    within a pool_size sized window by striding defined by stride.
+
+    In the default case, where the data_layout is `NCHW`
+    a data Tensor with shape `(batch_size, in_channels, height, width)`,
+    to produce an output Tensor with the following rule:
+
+    with data of shape (b, c, h, w) and pool_size (kh, kw)
+
+    .. math::
+
+        \mbox{out}(b, c, y, x)  = \frac{1}{kh * kw} \sum_{m=0, \ldots, kh-1}
+            \sum_{n=0, \ldots, kw-1}
+            \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + 
n)
+
+    Padding is applied to data before the computation.
+    ceil_mode is used to take ceil or floor while computing out shape.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int, int]]
+        The size of window for pooling. It is required to have length either 1 
or 2.
+
+    strides : Union[int, Tuple[int, int]]
+        The strides of pooling. It is required to have length either 1 or 2.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1, 2 or 
4.
+
+    dilation : Union[int, Tuple[int, int]]
+        The dilation of pooling. It is required to have length either 1 or 2.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a 
sliding window.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as layout.
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size, pool_size)
+    if isinstance(strides, int):
+        strides = (strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding)
+
+    return _ffi_api.avg_pool2d(  # type: ignore
+        data, pool_size, strides, padding, dilation, ceil_mode, layout, 
out_layout
+    )
+
+
 def adaptive_avg_pool2d(
     data: Expr,
     output_size: Optional[Union[int, Tuple[int, int]]] = None,
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 05531567bb..2d0fdd14b3 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -54,8 +54,13 @@ class Conv2DAttrs(Attrs):
     """Attributes for nn.conv2d"""
 
 
-@tvm._ffi.register_object("relax.attrs.MaxPool2DAttrs")
-class MaxPool2DAttrs(Attrs):
+@tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs")
+class Conv2DTransposeAttrs(Attrs):
+    """Attributes for nn.conv2d_transpose"""
+
+
+@tvm._ffi.register_object("relax.attrs.Pool2DAttrs")
+class Pool2DAttrs(Attrs):
     """Attributes for nn.max_pool2d"""
 
 
@@ -127,3 +132,13 @@ class Resize2DAttrs(Attrs):
 @tvm._ffi.register_object("relax.attrs.ArgmaxArgminAttrs")
 class ArgmaxArgminAttrs(Attrs):
     """Attributes for argmax/argmin operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.RepeatAttrs")
+class RepeatAttrs(Attrs):
+    """Attributes for repeat operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.TileAttrs")
+class TileAttrs(Attrs):
+    """Attributes for tile operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/common.py 
b/python/tvm/relax/transform/legalize_ops/common.py
index 4407b3fdf3..4ee9c6758f 100644
--- a/python/tvm/relax/transform/legalize_ops/common.py
+++ b/python/tvm/relax/transform/legalize_ops/common.py
@@ -57,7 +57,7 @@ def _try_convert_to_scalar_const(
         The expr to be checked and converted.
 
     Returns
-    --–----
+    -------
     ret : Union[Expr, FloatImm, IntImm, bool, float, int]
         Return a FloatImm or IntImm if the given expr is a scalar integer or 
float constant, and the
         python native flag is False. Or return the plain value of the constant 
in native python type
diff --git a/python/tvm/relax/transform/legalize_ops/index.py 
b/python/tvm/relax/transform/legalize_ops/index.py
index 9ee6b28130..eccccc7c6d 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -36,10 +36,8 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.strided_slice")
 def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
     if not all(
-        [
-            isinstance(call.args[0].struct_info.shape.values[i.value], 
tir.IntImm)
-            for i in call.attrs.axes
-        ]
+        isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm)
+        for i in call.attrs.axes
     ):
         logging.info(
             "Cases where an axis with symbolic length is sliced are not able "
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 5b992eff1d..7c67d5b26c 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -17,9 +17,11 @@
 # pylint: disable=invalid-name
 """Default legalization function for manipulate operators."""
 import logging
+from typing import Optional
 
 import tvm
 from tvm import topi, tir, relax, te
+from tvm.tir.expr import IntImm
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr, Var, Tuple, TupleGetItem
 from .common import TEFunc, LegalizeFunc, register_legalize
@@ -117,3 +119,26 @@ def _split(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.squeeze")
 def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
+
+
+@register_legalize("relax.repeat")
+def _repeat(bb: BlockBuilder, call: Call) -> Expr:
+    def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
+        if axis is None:
+            # flatten data
+            out_shape = data.shape[0]
+            for i in data.shape[1:]:
+                out_shape *= i
+            data = topi.reshape(data, (out_shape,))
+            axis = 0
+        # topi only receives int repeats and axis
+        return topi.repeat(data, int(repeats), int(axis))
+
+    return bb.call_te(
+        te_repeat, call.args[0], call.attrs.repeats, call.attrs.axis, 
primfunc_name_hint="repeat"
+    )
+
+
+@register_legalize("relax.tile")
+def _tile(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(topi.tile, call.args[0], call.attrs.repeats)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index a61e0cd09e..bfc0544536 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -65,6 +65,41 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.conv2d_transpose")
+def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.data_layout:
+        logging.info(
+            "TOPI conv2d_transpose does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+    if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW":
+        logging.info(
+            "TOPI conv2d_transpose does not support input layout other than 
NCHW, "
+            "and kernel layout other than IOHW, so cannot be legalized by TOPI"
+        )
+        return call
+    dilation = call.attrs.dilation
+    if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1:
+        logging.info(
+            "TOPI conv2d_transpose does not support dilations other than 1, "
+            "and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.group_conv2d_transpose_nchw,
+        call.args[0],
+        call.args[1],
+        stride=call.attrs.strides,
+        padding=call.attrs.padding,
+        out_dtype=call.struct_info.dtype,
+        output_padding=call.attrs.output_padding,
+        groups=call.attrs.groups,
+        primfunc_name_hint="conv2d_transpose",
+    )
+
+
 @register_legalize("relax.nn.max_pool2d")
 def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
@@ -88,6 +123,29 @@ def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.avg_pool2d")
+def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.layout:
+        logging.info(
+            "TOPI avg_pool2d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.pool2d,
+        call.args[0],
+        kernel=call.attrs.pool_size,
+        stride=call.attrs.strides,
+        dilation=call.attrs.dilation,
+        padding=call.attrs.padding,
+        pool_type="avg",
+        ceil_mode=call.attrs.ceil_mode,
+        layout=call.attrs.layout,
+        primfunc_name_hint="avg_pool2d",
+    )
+
+
 @register_legalize("relax.nn.adaptive_avg_pool2d")
 def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 9ef403181b..6166ae6330 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -93,6 +93,7 @@ from tvm.relax.op import (
     power,
     print,
     prod,
+    repeat,
     reshape,
     round,
     shape_of,
@@ -112,6 +113,7 @@ from tvm.relax.op import (
     subtract,
     tan,
     tanh,
+    tile,
     tril,
     triu,
     unique,
@@ -599,6 +601,7 @@ __all__ = [
     "prim_value",
     "print",
     "prod",
+    "repeat",
     "reshape",
     "round",
     "shape",
@@ -622,6 +625,7 @@ __all__ = [
     "take",
     "tan",
     "tanh",
+    "tile",
     "tril",
     "triu",
     "tuple",
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index a3ddd3e350..8dc3c9696f 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -142,5 +142,149 @@ TVM_REGISTER_OP("relax.nn.conv2d")
     .set_attrs_type<Conv2DAttrs>()
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d);
 
+/* relax.nn.conv2d_transpose */
+TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
+
+Expr conv2d_transpose(Expr data, Expr weight, Array<IntImm> strides, 
Array<IntImm> padding,
+                      Array<IntImm> output_padding, Array<IntImm> dilation, 
int groups,
+                      String data_layout, String kernel_layout, 
Optional<String> out_layout,
+                      DataType out_dtype) {
+  padding = GetCompletePadding2D(std::move(padding));
+  if (output_padding.size() == 1) {
+    output_padding.push_back(output_padding[0]);
+  }
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+  }
+
+  CHECK_GT(groups, 0) << "The number of groups in convolution is expected to 
be positive. However, "
+                         "the given number of groups is "
+                      << groups;
+  CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is 
expected to be 2. "
+                                        "However, the given output_padding is "
+                                     << output_padding;
+  CHECK_EQ(strides.size(), 2)
+      << "The input strides length is expected to be 2. However, the given 
strides is " << strides;
+  CHECK_EQ(dilation.size(), 2)
+      << "The input dilation length is expected to be 2. However, the given 
dilation is "
+      << dilation;
+
+  auto attrs = make_object<Conv2DTransposeAttrs>();
+  attrs->strides = ConvertIntImmToInt64(strides);
+  attrs->padding = ConvertIntImmToInt64(padding);
+  attrs->output_padding = ConvertIntImmToInt64(output_padding);
+  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = data_layout;
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout.value_or(data_layout));
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get("relax.nn.conv2d_transpose");
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv2d_transpose").set_body_typed(conv2d_transpose);
+
+StructInfo InferStructInfoConv2dTranspose(const Call& call, const 
BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo weight_sinfo = input_sinfo[1];
+
+  const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, 
attrs->data_layout,  //
+                                                    /*tgt_layout=*/"NCHW",     
     //
+                                                    /*tensor_name=*/"data");
+  auto [weight_layout, weight2IOHW] = CheckTensorLayout(call, ctx, 
attrs->kernel_layout,  //
+                                                        /*tgt_layout=*/"IOHW", 
           //
+                                                        
/*tensor_name=*/"kernel");
+  auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, 
attrs->out_layout,  //
+                                                  /*tgt_layout=*/"NCHW",       
  //
+                                                  /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  Optional<ShapeExpr> weight_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+  DataType out_dtype = attrs->out_dtype.is_void()
+                           ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, 
weight_sinfo)
+                           : attrs->out_dtype;
+  if (!data_shape.defined() || !weight_shape.defined()) {
+    return TensorStructInfo(out_dtype, out_layout.ndim());
+  }
+
+  Array<PrimExpr> data_NCHW_shape = 
data2NCHW.ForwardShape(data_shape.value()->values);
+  Array<PrimExpr> weight_IOHW_shape = 
weight2IOHW.ForwardShape(weight_shape.value()->values);
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr input_channel_data = data_NCHW_shape[1];
+  PrimExpr input_channel_kernel = weight_IOHW_shape[0];
+  if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "Conv2dTranspose expects the channel size of the data should equal 
to the input channel "
+           "size of the weight. However, the data channel size is "
+        << input_channel_data << " while the weight input channel size is "
+        << input_channel_kernel);
+  } else if (!analyzer->CanProveEqual(input_channel_data, 
input_channel_kernel)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv2dTranspose expects the number of input channels 
to be divisible by "
+                        "the number of groups. However, the number of input 
channels is "
+                     << input_channel_kernel << " while the number of groups 
is " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel, 
attrs->groups), 0)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(attrs->output_padding[0]->value >= 
attrs->strides[0]->value ||
+                         attrs->output_padding[1]->value >= 
attrs->strides[1]->value)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv2dTranspose expects the output padding less than 
the strides, but the "
+                        "output padding is "
+                     << attrs->output_padding << " while the strides are " << 
attrs->strides);
+  } else if (!analyzer->CanProve(attrs->output_padding[0]->value < 
attrs->strides[0]->value &&
+                                 attrs->output_padding[1]->value < 
attrs->strides[1]->value)) {
+    // Todo(relax-team): Trust the input padding at this moment, and revisit
+    // this condition with runtime shape check
+  }
+
+  PrimExpr input_h = data_NCHW_shape[2];
+  PrimExpr input_w = data_NCHW_shape[3];
+  PrimExpr kernel_h = weight_IOHW_shape[2];
+  PrimExpr kernel_w = weight_IOHW_shape[3];
+  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
+  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+
+  std::vector<PrimExpr> out_NCHW_shape;
+  out_NCHW_shape.resize(4);
+  out_NCHW_shape[0] = data_NCHW_shape[0];
+  out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups;
+
+  PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h +
+                   attrs->dilation[0] * (kernel_h - 1) + 
attrs->output_padding[0] + 1;
+  PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w +
+                   attrs->dilation[1] * (kernel_w - 1) + 
attrs->output_padding[1] + 1;
+  out_NCHW_shape[2] = analyzer->Simplify(out_h);
+  out_NCHW_shape[3] = analyzer->Simplify(out_w);
+
+  Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for 
conv2d_transpose
+// and unit test for mixed_precision
+TVM_REGISTER_OP("relax.nn.conv2d_transpose")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_attrs_type<Conv2DTransposeAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoConv2dTranspose);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index a65617b48d..7093c6a4d9 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -57,6 +57,17 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, 
Array<IntImm> padding
             Array<IntImm> dilation, int groups, String data_layout, String 
kernel_layout,
             Optional<String> out_layout, DataType out_dtype);
 
+/*!
+ * \brief Two dimensional transposed convolution operator.
+ *
+ * This operator is intended to be the backward operator of conv2d. It can be used to calculate
+ * the gradient of the result of conv2d w.r.t. the input of conv2d.
+ *
+ * Struct info inference for this operator (InferStructInfoConv2dTranspose in convolution.cc)
+ * computes, on the NCHW data / IOHW weight views, each output spatial extent as
+ *   (in - 1) * stride - (padding[k] + padding[k + 2]) + dilation * (kernel - 1)
+ *     + output_padding + 1
+ * with k = 0 for height and k = 1 for width, and the output channel extent as the weight's
+ * O extent multiplied by groups.
+ */
+Expr conv2d_transpose(Expr data, Expr weight, Array<IntImm> strides, 
Array<IntImm> padding,
+                      Array<IntImm> output_padding, Array<IntImm> dilation, 
int groups,
+                      String data_layout, String kernel_layout, 
Optional<String> out_layout,
+                      DataType out_dtype);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index a4c1e6b17d..61001ce678 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -25,11 +25,11 @@
 namespace tvm {
 namespace relax {
 
-/* relax.nn.max_pool2d */
-TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+/* relax.nn.max_pool2d and relax.nn.avg_pool2d */
+TVM_REGISTER_NODE_TYPE(Pool2DAttrs);
 
-Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
-                Array<IntImm> dilation, bool ceil_mode, String layout,
+Expr MakePool2d(String op_name, Expr data, Array<IntImm> pool_size, 
Array<IntImm> strides,
+                Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode, 
String layout,
                 Optional<String> out_layout) {
   padding = GetCompletePadding2D(std::move(padding));
   if (pool_size.size() == 1) {
@@ -51,24 +51,31 @@ Expr max_pool2d(Expr data, Array<IntImm> pool_size, 
Array<IntImm> strides, Array
       << "The input dilation length is expected to be 2. However, the given 
dilation is "
       << dilation;
 
-  auto attrs = make_object<MaxPool2DAttrs>();
-  attrs->pool_size = std::move(pool_size);
+  auto attrs = make_object<Pool2DAttrs>();
+  attrs->pool_size = ConvertIntImmToInt64(pool_size);
   attrs->strides = ConvertIntImmToInt64(strides);
   attrs->padding = ConvertIntImmToInt64(padding);
   attrs->dilation = ConvertIntImmToInt64(dilation);
   attrs->ceil_mode = ceil_mode;
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
-  static const Op& op = Op::Get("relax.nn.max_pool2d");
+  const Op& op = Op::Get(op_name);
   return Call(op, {std::move(data)}, Attrs(attrs), {});
 }
 
+/*!
+ * \brief Build a call to relax.nn.max_pool2d.
+ *
+ * Thin wrapper: every argument is forwarded unchanged to MakePool2d together with the
+ * operator name "relax.nn.max_pool2d".
+ */
+Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, String layout,
+                Optional<String> out_layout) {
+  // The parameters are by-value sinks that are not used again here, so they are moved into
+  // MakePool2d (which also takes them by value) instead of being copied a second time.
+  return MakePool2d("relax.nn.max_pool2d", std::move(data), std::move(pool_size),
+                    std::move(strides), std::move(padding), std::move(dilation), ceil_mode,
+                    std::move(layout), std::move(out_layout));
+}
+
 TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d);
 
-StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) 
{
+StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
 
-  const auto* attrs = call->attrs.as<MaxPool2DAttrs>();
+  const auto* attrs = call->attrs.as<Pool2DAttrs>();
   auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,  
//
                                                     /*tgt_layout=*/"NCHW",     
//
                                                     /*tensor_name=*/"data");
@@ -113,8 +120,23 @@ StructInfo InferStructInfoMaxPool2D(const Call& call, 
const BlockBuilder& ctx) {
 TVM_REGISTER_OP("relax.nn.max_pool2d")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor")
-    .set_attrs_type<MaxPool2DAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMaxPool2D);
+    .set_attrs_type<Pool2DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+
+/*!
+ * \brief Build a call to relax.nn.avg_pool2d.
+ *
+ * Thin wrapper: every argument is forwarded unchanged to MakePool2d together with the
+ * operator name "relax.nn.avg_pool2d".
+ */
+Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, String layout,
+                Optional<String> out_layout) {
+  // The parameters are by-value sinks that are not used again here, so they are moved into
+  // MakePool2d (which also takes them by value) instead of being copied a second time.
+  return MakePool2d("relax.nn.avg_pool2d", std::move(data), std::move(pool_size),
+                    std::move(strides), std::move(padding), std::move(dilation), ceil_mode,
+                    std::move(layout), std::move(out_layout));
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d);
+
+// relax.nn.avg_pool2d uses the same attribute type (Pool2DAttrs) and the same struct info
+// inference function (InferStructInfoPool2D) as relax.nn.max_pool2d.
+TVM_REGISTER_OP("relax.nn.avg_pool2d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<Pool2DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
 
 /* relax.nn.adaptive_avg_pool2d */
 TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index 3c1792d21f..63d2e76772 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -36,6 +36,10 @@ namespace relax {
 Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
                 Array<IntImm> dilation, bool ceil_mode, String layout, 
Optional<String> out_layout);
 
+/*!
+ * \brief 2D average pooling operator.
+ *
+ * Declared with the same parameter list as max_pool2d above. The definition in pooling.cc
+ * forwards all arguments to MakePool2d with the operator name "relax.nn.avg_pool2d".
+ */
+Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, String layout, 
Optional<String> out_layout);
+
 /*! \brief 2D adaptive average pooling operator. */
 Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, 
String layout,
                          Optional<String> out_layout);
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index e146a604af..c3a673f8fa 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -969,5 +969,131 @@ TVM_REGISTER_OP("relax.collapse_sum_to")
     .add_argument("shape", "Shape", "The shape to collapse to.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCollapseSumTo);
 
+/* relax.repeat */
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+/*!
+ * \brief Build a call to the relax.repeat operator.
+ * \param data The input tensor; becomes the single argument of the call.
+ * \param repeats The number of repetitions, stored on RepeatAttrs.
+ * \param axis The optional axis to repeat along, stored on RepeatAttrs.
+ * \return A Call to relax.repeat carrying the attributes above.
+ */
+Expr repeat(Expr data, int repeats, Optional<Integer> axis) {
+  auto attrs = make_object<RepeatAttrs>();
+  // `repeats` is a plain int; std::move on it would only be a copy, so assign it directly.
+  attrs->repeats = repeats;
+  attrs->axis = std::move(axis);
+
+  static const Op& op = Op::Get("relax.repeat");
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat);
+
+/*!
+ * \brief Infer the struct info of a relax.repeat call.
+ *
+ * With an axis, the extent of that axis is multiplied by `repeats`. Without an axis the result
+ * is one-dimensional with extent (product of all input extents) * `repeats`. When the input
+ * shape is not a ShapeExpr, only dtype and ndim are propagated.
+ */
+StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<RepeatAttrs>();
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  const bool has_axis = attrs->axis.defined();
+
+  // The axis can only be range-checked when the rank of the input is known.
+  if (has_axis && !data_sinfo->IsUnknownNdim()) {
+    int axis = attrs->axis.value()->value;
+    int ndim = data_sinfo->ndim;
+    const bool axis_in_range = -ndim <= axis && axis < ndim;
+    if (!axis_in_range) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << "Repeat requires the input axis belongs range "
+             "[-data.struct_info.ndim, data.struct_info.ndim - 1]. However, the input axis is "
+          << axis << ", while ndim is " << ndim);
+    }
+  }
+
+  // No symbolic shape available: report dtype and rank only.
+  if (data_shape == nullptr) {
+    if (!has_axis) {
+      // The flattened result always has exactly one dimension.
+      return TensorStructInfo(data_sinfo->dtype, 1);
+    }
+    if (analyzer->CanProveEqual(attrs->repeats, 1)) {
+      // Repeating once does not change the shape, so the input struct info is returned as is.
+      return data_sinfo;
+    }
+    return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+  }
+
+  // Flattened repeat: a single extent equal to the element count times `repeats`.
+  if (!has_axis) {
+    PrimExpr flat_extent =
+        analyzer->Simplify(ComputeShapeProduct(data_shape->values) * attrs->repeats);
+    return TensorStructInfo(ShapeExpr(Array<PrimExpr>({flat_extent})), data_sinfo->dtype);
+  }
+
+  // Repeat along one axis: scale only that extent.
+  int axis = NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value);
+  Array<PrimExpr> repeated_shape = data_shape->values;
+  repeated_shape.Set(axis, analyzer->Simplify(repeated_shape[axis] * attrs->repeats));
+  return TensorStructInfo(ShapeExpr(repeated_shape), data_sinfo->dtype);
+}
+
+// TODO(relax-team): implement FRelaxInferLayout for repeat
+// Registers relax.repeat with one tensor input, attributes of type RepeatAttrs, and
+// InferStructInfoRepeat for struct info inference.
+TVM_REGISTER_OP("relax.repeat")
+    .set_attrs_type<RepeatAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRepeat);
+
+/* relax.tile */
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+/*!
+ * \brief Build a call to the relax.tile operator.
+ * \param data The input tensor; becomes the single argument of the call.
+ * \param repeats The repetition counts, stored on TileAttrs.
+ * \return A Call to relax.tile carrying the attributes above.
+ */
+Expr tile(Expr data, Array<Integer> repeats) {
+  static const Op& tile_op = Op::Get("relax.tile");
+
+  auto tile_attrs = make_object<TileAttrs>();
+  tile_attrs->repeats = std::move(repeats);
+  return Call(tile_op, {std::move(data)}, Attrs{tile_attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile);
+
+/*!
+ * \brief Infer the struct info of a relax.tile call.
+ *
+ * The output rank is max(number of repeats, data ndim). The repeats and the data shape are
+ * aligned at their trailing ends: aligned pairs are multiplied, and the unmatched leading
+ * entries of the longer one are copied to the output shape.
+ */
+StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<TileAttrs>();
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  int num_repeats = attrs->repeats.size();
+  int ndim = data_sinfo->ndim;
+
+  // No symbolic shape available: report dtype and rank only.
+  if (data_shape == nullptr) {
+    if (data_sinfo->IsUnknownNdim()) {
+      return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+    }
+    if (num_repeats > ndim) {
+      return TensorStructInfo(data_sinfo->dtype, num_repeats);
+    }
+    for (auto factor : attrs->repeats) {
+      if (!analyzer->CanProveEqual(factor, 1)) {
+        return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+      }
+    }
+    // Every repeat factor is provably 1, so the shape should not be changed.
+    return data_sinfo;
+  }
+
+  int out_ndim = std::max(num_repeats, ndim);
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < out_ndim; ++i) {
+    // Positions in the data shape / repeats once both are right-aligned to out_ndim;
+    // a negative position means that sequence has no entry for output axis i.
+    int data_idx = i - (out_ndim - ndim);
+    int repeat_idx = i - (out_ndim - num_repeats);
+    if (repeat_idx < 0) {
+      out_shape.push_back(data_shape->values[data_idx]);
+    } else if (data_idx < 0) {
+      out_shape.push_back(attrs->repeats[repeat_idx]);
+    } else {
+      out_shape.push_back(
+          analyzer->Simplify(data_shape->values[data_idx] * attrs->repeats[repeat_idx]));
+    }
+  }
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
+}
+
+// TODO(relax-team): implement FRelaxInferLayout for tile
+// Registers relax.tile with one tensor input, attributes of type TileAttrs, and
+// InferStructInfoTile for struct info inference.
+TVM_REGISTER_OP("relax.tile")
+    .set_attrs_type<TileAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTile);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 95e29a3dce..fb75664a1d 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -133,6 +133,33 @@ Expr collapse_sum_like(Expr data, Expr collapse_target);
  */
 Expr collapse_sum_to(Expr data, Expr shape);
 
+/*!
+ * \brief Repeats elements of an array.
+ * \param data The input tensor.
+ * \param repeats The number of repetitions.
+ * \param axis The axis along which to repeat values. Negative values count backward from the
+ * last axis. When omitted, the input is treated as flattened and the output is
+ * one-dimensional.
+ * \return The computed result.
+ */
+Expr repeat(Expr data, int repeats, Optional<Integer> axis = NullOpt);
+
+/*!
+ * \brief Construct an array by repeating data the number of times given by repeats.
+ *
+ * If repeats has length l, and data has dimension d, the result will have dimension of
+ * max(l, d).
+ *
+ * If d < l, data is promoted to be l-dimensional by prepending new axes. So a shape (3,) Tensor
+ * is promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D replication. If this is
+ * not the desired behavior, promote data to l dimensions manually before calling this function.
+ *
+ * If d > l, repeats is promoted to length d by prepending 1's to it. Thus for a data of shape
+ * (2, 3, 4, 5), a repeats of (2, 2) is treated as (1, 1, 2, 2).
+ * \param data The input tensor.
+ * \param repeats The number of repetitions of data along each axis.
+ * \return The computed result.
+ */
+Expr tile(Expr data, Array<Integer> repeats);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_op_manipulate.py 
b/tests/python/relax/test_op_manipulate.py
index abb414b472..e1f550cc38 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -21,6 +21,7 @@ from tvm import relax, tir
 from tvm import TVMError
 from tvm.ir import Op
 from tvm.script import relax as R
+from tvm.tir.expr import FloatImm, IntImm
 
 
 def test_op_correctness():
@@ -32,6 +33,8 @@ def test_op_correctness():
     assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims")
     assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
     assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
+    assert relax.op.tile(x, (2, 2, 2)).op == Op.get("relax.tile")
+    assert relax.op.repeat(x, 2, 0).op == Op.get("relax.repeat")
     assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
     assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, 
a)).op == Op.get(
         "relax.layout_transform"
@@ -2704,5 +2707,248 @@ def 
test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
     )
 
 
+def test_repeat_infer_struct_info():
+    """Struct info inference of repeat for static, rank-only and fully unknown inputs."""
+    bb = relax.BlockBuilder()
+    tensor_sinfo = relax.TensorStructInfo
+    shaped_f32 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    rank3_f32 = relax.Var("x", R.Tensor("float32", ndim=3))
+    unranked_f32 = relax.Var("x", R.Tensor("float32"))
+    shaped = relax.Var("x", R.Tensor((2, 10, 4)))
+    rank3 = relax.Var("x", R.Tensor(ndim=3))
+    unranked = relax.Var("x", R.Tensor())
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (relax.op.repeat(shaped_f32, 2, axis=0), tensor_sinfo((4, 10, 4), "float32")),
+        (relax.op.repeat(shaped_f32, 2, axis=-2), tensor_sinfo((2, 20, 4), "float32")),
+        (relax.op.repeat(shaped_f32, 2), tensor_sinfo((160,), "float32")),
+        (relax.op.repeat(rank3_f32, 2, axis=0), tensor_sinfo(dtype="float32", ndim=3)),
+        (relax.op.repeat(rank3_f32, 2), tensor_sinfo(dtype="float32", ndim=1)),
+        (relax.op.repeat(unranked_f32, 2, axis=0), tensor_sinfo(dtype="float32")),
+        (relax.op.repeat(unranked_f32, 2), tensor_sinfo(dtype="float32", ndim=1)),
+        (relax.op.repeat(shaped, 2, axis=0), tensor_sinfo((4, 10, 4), dtype="")),
+        (relax.op.repeat(rank3, 2, axis=0), tensor_sinfo(dtype="", ndim=3)),
+        (relax.op.repeat(unranked, 2, axis=0), tensor_sinfo(dtype="")),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_repeat_infer_struct_info_shape_symbolic():
+    """Struct info inference of repeat on a tensor whose extents are symbolic int64 vars."""
+    bb = relax.BlockBuilder()
+    tensor_sinfo = relax.TensorStructInfo
+    a, b, c = (tir.Var(name, "int64") for name in ("a", "b", "c"))
+    x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (relax.op.repeat(x, 2, 0), tensor_sinfo((a * 2, b, c), "float32")),
+        (relax.op.repeat(x, 2, -1), tensor_sinfo((a, b, c * 2), "float32")),
+        (relax.op.repeat(x, 2), tensor_sinfo((a * b * c * 2,), "float32")),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_repeat_infer_struct_info_more_input_dtype():
+    """The dtype of the input (float16, int8) is carried over to the inferred struct info."""
+    bb = relax.BlockBuilder()
+    # Build every input first, then run the checks, one per dtype.
+    inputs = [
+        (dtype, relax.Var("x", R.Tensor((2, 3, 4), dtype))) for dtype in ("float16", "int8")
+    ]
+
+    for dtype, x in inputs:
+        _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((4, 3, 4), dtype))
+
+
+def test_repeat_infer_struct_info_axis_out_of_range():
+    """Axes 3 and -4 are rejected for 3-d inputs and accepted when the rank is unknown."""
+    bb = relax.BlockBuilder()
+    shaped = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    rank3 = relax.Var("x", R.Tensor("float32", ndim=3))
+    unranked = relax.Var("x", R.Tensor("float32"))
+    bad_axes = (3, -4)
+
+    for data in (shaped, rank3):
+        for axis in bad_axes:
+            with pytest.raises(TVMError):
+                bb.normalize(relax.op.repeat(data, 2, axis))
+
+    # okay
+    for axis in bad_axes:
+        bb.normalize(relax.op.repeat(unranked, 2, axis))
+
+
+def test_repeat_return_data_sinfo():
+    """With repeats == 1 along an axis the inferred struct info equals the input's."""
+    bb = relax.BlockBuilder()
+    shaped = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    rank3 = relax.Var("x", R.Tensor("float32", ndim=3))
+    unranked = relax.Var("x", R.Tensor("float32"))
+
+    for data, axis in [(shaped, 0), (shaped, -1), (rank3, 0), (unranked, 0)]:
+        _check_inference(bb, relax.op.repeat(data, 1, axis), data.struct_info)
+
+
+def test_repeat_infer_struct_info_wrong_input_type():
+    """Non-tensor data and non-integer repeats must raise TVMError."""
+    bb = relax.BlockBuilder()
+    shape_input = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+    func_input = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))
+    tensor_input = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    float_var = tir.Var("r", "float32")
+    string_imm = tir.StringImm("abc")
+
+    bad_calls = [
+        (shape_input, 2),
+        (func_input, 2),
+        (tensor_input, 1.5),
+        (tensor_input, float_var),
+        (tensor_input, string_imm),
+    ]
+    for data, repeats in bad_calls:
+        # The call is built inside the context manager because construction itself may raise.
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.repeat(data, repeats))
+
+
+def test_tile_infer_struct_info():
+    """Struct info inference of tile for static, rank-only and fully unknown inputs."""
+    bb = relax.BlockBuilder()
+    tensor_sinfo = relax.TensorStructInfo
+    shaped_f32 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    rank3_f32 = relax.Var("x", R.Tensor("float32", ndim=3))
+    unranked_f32 = relax.Var("x", R.Tensor("float32"))
+    shaped = relax.Var("x", R.Tensor((2, 10, 4)))
+    rank3 = relax.Var("x", R.Tensor(ndim=3))
+    unranked = relax.Var("x", R.Tensor())
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (relax.op.tile(shaped_f32, 2), tensor_sinfo((2, 10, 8), "float32")),
+        (relax.op.tile(shaped_f32, (3, 2)), tensor_sinfo((2, 30, 8), "float32")),
+        (relax.op.tile(shaped_f32, (4, 3, 2)), tensor_sinfo((8, 30, 8), "float32")),
+        (relax.op.tile(shaped_f32, (5, 4, 3, 2)), tensor_sinfo((5, 8, 30, 8), "float32")),
+        (relax.op.tile(rank3_f32, 2), tensor_sinfo(dtype="float32", ndim=3)),
+        (relax.op.tile(rank3_f32, (5, 4, 3, 2)), tensor_sinfo(dtype="float32", ndim=4)),
+        (relax.op.tile(unranked_f32, (5, 4, 3, 2)), tensor_sinfo(dtype="float32")),
+        (relax.op.tile(shaped, 2), tensor_sinfo((2, 10, 8), dtype="")),
+        (relax.op.tile(shaped, (5, 4, 3, 2)), tensor_sinfo((5, 8, 30, 8), dtype="")),
+        (relax.op.tile(rank3, 2), tensor_sinfo(dtype="", ndim=3)),
+        (relax.op.tile(rank3, (5, 4, 3, 2)), tensor_sinfo(dtype="", ndim=4)),
+        (relax.op.tile(unranked, (5, 4, 3, 2)), tensor_sinfo(dtype="")),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_tile_infer_struct_info_shape_symbolic():
+    """Struct info inference of tile on a tensor whose extents are symbolic int64 vars."""
+    bb = relax.BlockBuilder()
+    tensor_sinfo = relax.TensorStructInfo
+    a, b, c = (tir.Var(name, "int64") for name in ("a", "b", "c"))
+    x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (relax.op.tile(x, 2), tensor_sinfo((a, b, c * 2), "float32")),
+        (relax.op.tile(x, (3, 2)), tensor_sinfo((a, b * 3, c * 2), "float32")),
+        (relax.op.tile(x, (4, 3, 2)), tensor_sinfo((a * 4, b * 3, c * 2), "float32")),
+        (relax.op.tile(x, (5, 4, 3, 2)), tensor_sinfo((5, a * 4, b * 3, c * 2), "float32")),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_tile_infer_struct_info_more_input_dtype():
+    """The dtype of the input (float16, int8) is carried over to the inferred struct info."""
+    bb = relax.BlockBuilder()
+    # Build every input first, then run the checks, one per dtype.
+    inputs = [
+        (dtype, relax.Var("x", R.Tensor((2, 3, 4), dtype))) for dtype in ("float16", "int8")
+    ]
+
+    for dtype, x in inputs:
+        _check_inference(bb, relax.op.tile(x, (3, 2)), relax.TensorStructInfo((2, 9, 8), dtype))
+
+
+def test_tile_return_data_sinfo():
+    """When every repeat factor is 1 the inferred struct info equals the input's."""
+    bb = relax.BlockBuilder()
+    shaped = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+    rank3 = relax.Var("x", R.Tensor("float32", ndim=3))
+    unranked = relax.Var("x", R.Tensor("float32"))
+
+    unit_tiles = [
+        (shaped, 1),
+        (shaped, (1, 1)),
+        (shaped, (1, 1, 1)),
+        (rank3, 1),
+        (unranked, 1),
+    ]
+    for data, repeats in unit_tiles:
+        _check_inference(bb, relax.op.tile(data, repeats), data.struct_info)
+
+
+def test_tile_infer_struct_info_wrong_input_type():
+    """Non-tensor data and non-integer repeats must raise TVMError."""
+    bb = relax.BlockBuilder()
+    shape_input = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+    func_input = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32")))
+    tensor_input = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    float_var = tir.Var("a", "float32")
+    string_imm = tir.StringImm("abc")
+
+    bad_calls = [
+        (shape_input, 2),
+        (func_input, 2),
+        (tensor_input, (2, 1.5, 2)),
+        (tensor_input, (2, float_var)),
+        (tensor_input, string_imm),
+    ]
+    for data, repeats in bad_calls:
+        # The call is built inside the context manager because construction itself may raise.
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.tile(data, repeats))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_convolution.py 
b/tests/python/relax/test_op_nn_convolution.py
index 6533d43420..334f6977f7 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -27,6 +27,7 @@ def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
     assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d")
+    assert relax.op.nn.conv2d_transpose(x, w).op == 
Op.get("relax.nn.conv2d_transpose")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
@@ -425,5 +426,389 @@ def test_conv2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.conv2d(x1, w0))
 
 
+def test_conv2d_transpose_infer_struct_info():
+    """Struct info inference of conv2d_transpose across attributes, layouts and unknown inputs."""
+    bb = relax.BlockBuilder()
+    conv_t = relax.op.nn.conv2d_transpose
+    tensor_sinfo = relax.TensorStructInfo
+    x_nchw = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x_nhwc = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
+    x_rank4 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x_unranked = relax.Var("x", R.Tensor("float32"))
+    x_untyped = relax.Var("x", R.Tensor())
+    x_packed = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+    w_iohw = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    w_oihw = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    w_rank4 = relax.Var("w", R.Tensor("float32", ndim=4))
+    w_unranked = relax.Var("w", R.Tensor("float32"))
+    w_packed = relax.Var("w", R.Tensor((4, 48, 3, 3, 16), "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (conv_t(x_nchw, w_iohw), tensor_sinfo((2, 4, 30, 30), "float32")),
+        (
+            conv_t(x_nchw, w_iohw, out_dtype="float16"),
+            tensor_sinfo((2, 4, 30, 30), "float16"),
+        ),
+        (conv_t(x_nchw, w_iohw, padding=1), tensor_sinfo((2, 4, 28, 28), "float32")),
+        (conv_t(x_nchw, w_iohw, padding=[1, 2]), tensor_sinfo((2, 4, 28, 26), "float32")),
+        (
+            conv_t(x_nchw, w_iohw, padding=[1, 2, 3, 4]),
+            tensor_sinfo((2, 4, 26, 24), "float32"),
+        ),
+        (
+            conv_t(x_nchw, w_iohw, strides=3, output_padding=1),
+            tensor_sinfo((2, 4, 85, 85), "float32"),
+        ),
+        (
+            conv_t(x_nchw, w_iohw, strides=3, output_padding=[2, 1]),
+            tensor_sinfo((2, 4, 86, 85), "float32"),
+        ),
+        (conv_t(x_nchw, w_iohw, strides=2), tensor_sinfo((2, 4, 57, 57), "float32")),
+        (conv_t(x_nchw, w_iohw, strides=(2, 3)), tensor_sinfo((2, 4, 57, 84), "float32")),
+        (conv_t(x_nchw, w_iohw, dilation=2), tensor_sinfo((2, 4, 32, 32), "float32")),
+        (conv_t(x_nchw, w_iohw, dilation=(2, 1)), tensor_sinfo((2, 4, 32, 30), "float32")),
+        (
+            conv_t(x_nhwc, w_iohw, data_layout="NHWC"),
+            tensor_sinfo((2, 30, 30, 4), "float32"),
+        ),
+        (
+            conv_t(x_nchw, w_iohw, out_layout="NHWC"),
+            tensor_sinfo((2, 30, 30, 4), "float32"),
+        ),
+        (
+            conv_t(x_nchw, w_oihw, kernel_layout="OIHW"),
+            tensor_sinfo((2, 4, 30, 30), "float32"),
+        ),
+        (
+            conv_t(
+                x_packed,
+                w_packed,
+                data_layout="NCHW16c",
+                kernel_layout="IOHW16i",
+                out_layout="NHWC16c",
+            ),
+            tensor_sinfo((2, 30, 30, 3, 16), "float32"),
+        ),
+        (conv_t(x_rank4, w_iohw), tensor_sinfo(dtype="float32", ndim=4)),
+        (conv_t(x_unranked, w_iohw), tensor_sinfo(dtype="float32", ndim=4)),
+        (conv_t(x_nchw, w_rank4), tensor_sinfo(dtype="float32", ndim=4)),
+        (conv_t(x_nchw, w_unranked), tensor_sinfo(dtype="float32", ndim=4)),
+        (conv_t(x_untyped, w_iohw), tensor_sinfo(dtype="", ndim=4)),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_conv2d_transpose_infer_struct_info_shape_symbolic():
+    """Struct info inference of conv2d_transpose when all extents are symbolic int64 vars."""
+    bb = relax.BlockBuilder()
+    conv_t = relax.op.nn.conv2d_transpose
+    tensor_sinfo = relax.TensorStructInfo
+    n, c, c16, ih, iw, ki, ko, kh, kw = (
+        tir.Var(name, "int64") for name in ("n", "c", "c16", "ih", "iw", "ki", "ko", "kh", "kw")
+    )
+    x_nchw = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x_packed = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+    w_free_in = relax.Var("w", R.Tensor((ki, ko, kh, kw), "float32"))
+    w_same_in = relax.Var("w", R.Tensor((c, ko, kh, kw), "float32"))
+    w_packed = relax.Var("w", R.Tensor((c, ko, kh, kw, c16), "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (
+            conv_t(x_nchw, w_free_in),
+            tensor_sinfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+        ),
+        (
+            conv_t(x_nchw, w_same_in),
+            tensor_sinfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+        ),
+        (
+            conv_t(
+                x_packed,
+                w_packed,
+                data_layout="NCHW16c",
+                kernel_layout="IOHW16i",
+                out_layout="NCHW",
+            ),
+            tensor_sinfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+        ),
+        (
+            conv_t(
+                x_nchw,
+                w_free_in,
+                strides=(2, 2),
+                padding=(1, 1),
+                output_padding=(1, 0),
+                dilation=(2, 2),
+            ),
+            tensor_sinfo((n, ko, ih * 2 + kh * 2 - 4, iw * 2 + kw * 2 - 5), "float32"),
+        ),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_conv2d_transpose_infer_struct_info_shape_var():
+    """Inputs whose shapes are shape variables only yield dtype and ndim in the result."""
+    bb = relax.BlockBuilder()
+    conv_t = relax.op.nn.conv2d_transpose
+    tensor_sinfo = relax.TensorStructInfo
+    shape_rank4 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    shape_rank5 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    weight_shape = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    shape_unranked = relax.Var("s", relax.ShapeStructInfo())
+    x_rank4 = relax.Var("x", relax.TensorStructInfo(shape_rank4, "float32"))
+    x_rank5 = relax.Var("x", relax.TensorStructInfo(shape_rank5, "float32"))
+    x_unranked = relax.Var("x", relax.TensorStructInfo(shape_unranked, "float32"))
+    w = relax.Var("w", relax.TensorStructInfo(weight_shape, "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (conv_t(x_rank4, w), tensor_sinfo(dtype="float32", ndim=4)),
+        (conv_t(x_rank5, w, data_layout="NCHW16c"), tensor_sinfo(dtype="float32", ndim=5)),
+        (conv_t(x_rank4, w, out_layout="NCHW16c"), tensor_sinfo(dtype="float32", ndim=5)),
+        (conv_t(x_unranked, w), tensor_sinfo(dtype="float32", ndim=4)),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_conv2d_transpose_infer_struct_info_groups():
+    """Struct info inference of conv2d_transpose with groups=8 on plain and packed layouts."""
+    bb = relax.BlockBuilder()
+    conv_t = relax.op.nn.conv2d_transpose
+    tensor_sinfo = relax.TensorStructInfo
+    x_nchw = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    x_packed = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32"))
+    w_iohw = relax.Var("w", R.Tensor((128, 6, 3, 3), "float32"))
+    w_packed = relax.Var("w", R.Tensor((16, 6, 3, 3, 8), "float32"))
+
+    # (call under test, expected struct info), checked in this order.
+    cases = [
+        (conv_t(x_nchw, w_iohw, groups=8), tensor_sinfo((2, 48, 30, 30), "float32")),
+        (
+            conv_t(x_nchw, w_packed, kernel_layout="IOHW8i", groups=8),
+            tensor_sinfo((2, 48, 30, 30), "float32"),
+        ),
+        (
+            conv_t(x_packed, w_iohw, data_layout="NCHW16c", groups=8),
+            tensor_sinfo((2, 3, 30, 30, 16), "float32"),
+        ),
+    ]
+    for call, expected_sinfo in cases:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_conv2d_transpose_infer_struct_info_symbolic_groups():
+    """With groups=4 and symbolic channels, the output channel extent is oc * 4."""
+    bb = relax.BlockBuilder()
+    batch = tir.Var("n", "int64")
+    in_channels = tir.Var("c", "int64")
+    out_channels = tir.Var("oc", "int64")
+    data = relax.Var("x", R.Tensor((batch, in_channels * 4, 28, 28), "float32"))
+    weight = relax.Var("w", R.Tensor((in_channels, out_channels, 3, 3), "float32"))
+
+    call = relax.op.nn.conv2d_transpose(data, weight, groups=4)
+    expected_sinfo = relax.TensorStructInfo((batch, out_channels * 4, 30, 30), "float32")
+    _check_inference(bb, call, expected_sinfo)
+
+
+def test_conv2d_transpose_infer_struct_info_input_channel_group_incompatible():
+    """groups=6 must raise TVMError for both the static and the symbolic operand pair."""
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    static_x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    static_w = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32"))
+    symbolic_x = relax.Var("x", R.Tensor((n, ic, 28, 28), "float32"))
+    symbolic_w = relax.Var("w", R.Tensor((ic - 1, oc, 3, 3), "float32"))
+
+    for data, weight in [(static_x, static_w), (symbolic_x, symbolic_w)]:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.nn.conv2d_transpose(data, weight, groups=6))
+
+
+def test_conv2d_transpose_non_positive_group():
+    x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((128, 16, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, groups=0)
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, groups=-2)
+
+
+def test_conv2d_transpose_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float64"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    w2 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32"))
+    w3 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2, 4, 30, 30), "float16")
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d_transpose(x1, w1), relax.TensorStructInfo((2, 4, 30, 30), "float64")
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d_transpose(x2, w2), relax.TensorStructInfo((2, 4, 30, 30), "int8")
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d_transpose(x3, w3), relax.TensorStructInfo((2, 4, 30, 30), "int32")
+    )
+
+
+def test_conv2d_transpose_unequal_input_channel():
+    bb = relax.BlockBuilder()
+    ic = tir.Var("ic", "int64")
+    x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32"))
+    w0 = relax.Var("w", R.Tensor([4, 3, 3, 3], "float32"))
+    x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32"))
+    w1 = relax.Var("w", R.Tensor([ic + 2, 4, 3, 3], "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x1, w1))
+
+
+def test_conv2d_transpose_wrong_output_padding():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32"))
+    w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=2, output_padding=2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 2), output_padding=(2, 2)))
+
+
+def test_conv2d_transpose_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    conv2d_transpose = relax.op.nn.conv2d_transpose(
+        x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2), dilation=(1, 1)
+    )
+
+    assert conv2d_transpose.attrs.strides[0].dtype == "int64"
+    assert conv2d_transpose.attrs.strides[1].dtype == "int64"
+    assert conv2d_transpose.attrs.padding[0].dtype == "int64"
+    assert conv2d_transpose.attrs.padding[1].dtype == "int64"
+    assert conv2d_transpose.attrs.padding[2].dtype == "int64"
+    assert conv2d_transpose.attrs.padding[3].dtype == "int64"
+    assert conv2d_transpose.attrs.output_padding[0].dtype == "int64"
+    assert conv2d_transpose.attrs.output_padding[1].dtype == "int64"
+    assert conv2d_transpose.attrs.dilation[0].dtype == "int64"
+    assert conv2d_transpose.attrs.dilation[1].dtype == "int64"
+
+
+def test_conv2d_transpose_wrong_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, strides=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, output_padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d_transpose(x, w, dilation=(1, 2, 3))
+
+
+def test_conv2d_transpose_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x, w, data_layout="IOHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x, w, kernel_layout="NHWC"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x, w, out_layout="OHWI"))
+
+
+def test_conv2d_transpose_dtype_mismatch():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x, w))
+
+
+def test_conv2d_transpose_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 6, 3, 3), "float32"))
+    w2 = relax.Var("w", R.Tensor("float32", ndim=6))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w1, data_layout="NCHW16c"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x2, w0))
+
+
+def test_conv2d_transpose_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x0, w1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
index 0eec5de21c..2bd7747f31 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -26,6 +26,7 @@ from tvm.script import relax as R
 def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
+    assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d")
     assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d")
 
 
@@ -203,7 +204,7 @@ def test_max_pool2d_infer_struct_info_more_input_dtype():
     )
 
 
-def test_conv2d_stride_padding_dilation_int64():
+def test_max_pool2d_stride_padding_dilation_int64():
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
     max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1))
 
@@ -259,6 +260,231 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.max_pool2d(x1))
 
 
+def test_avg_pool2d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, pool_size=3),
+        relax.TensorStructInfo((2, 3, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)),
+        relax.TensorStructInfo((2, 3, 28, 30), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, padding=[1, 2]),
+        relax.TensorStructInfo((2, 3, 34, 36), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, strides=2),
+        relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, dilation=2),
+        relax.TensorStructInfo((2, 3, 32, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x1, layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x0, out_layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"),
+        relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(bb, relax.op.nn.avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4))
+    _check_inference(bb, relax.op.nn.avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4))
+
+
+def test_avg_pool2d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(
+            x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2)
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(ih - 1, 3) + 1,
+                tvm.tir.floordiv(iw - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"),
+        relax.TensorStructInfo((n, ih, iw, c * 16), "float32"),
+    )
+
+
+def test_avg_pool2d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x1, layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x2),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+
+
+def test_avg_pool2d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15, 16), "float32"),
+    )
+
+
+def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.avg_pool2d(
+            x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True
+        ),
+        relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"),
+    )
+
+
+def test_avg_pool2d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64"))
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16")
+    )
+    _check_inference(bb, relax.op.nn.avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8"))
+    _check_inference(
+        bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64")
+    )
+
+
+def test_avg_pool2d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1))
+
+    assert avg_pool2d.attrs.strides[0].dtype == "int64"
+    assert avg_pool2d.attrs.strides[1].dtype == "int64"
+    assert avg_pool2d.attrs.padding[0].dtype == "int64"
+    assert avg_pool2d.attrs.padding[1].dtype == "int64"
+    assert avg_pool2d.attrs.padding[2].dtype == "int64"
+    assert avg_pool2d.attrs.padding[3].dtype == "int64"
+    assert avg_pool2d.attrs.dilation[0].dtype == "int64"
+    assert avg_pool2d.attrs.dilation[1].dtype == "int64"
+
+
+def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool2d(x, pool_size=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool2d(x, strides=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool2d(x, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3))
+
+
+def test_avg_pool2d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x, layout="OIHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x, out_layout="OHWI"))
+
+
+def test_avg_pool2d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x1))
+
+
+def test_avg_pool2d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.avg_pool2d(x1))
+
+
 def test_adaptive_avg_pool2d_infer_struct_info():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 7ae0eb359a..fbb1024d26 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -18,7 +18,7 @@
 import pytest
 import tvm
 from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
 
 
@@ -888,5 +888,202 @@ def test_collapse_sum_to_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_repeat():
+    # fmt: off
+    @I.ir_module
+    class Repeat:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), "float32")):
+            gv = R.repeat(x, 2, 0)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2, 3), dtype="float32"):
+            gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((6, 2, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def repeat(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(6), T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1, ax2 in T.grid(T.int64(6), T.int64(2), T.int64(3)):
+                with T.block("T_repeat"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2])
+                    T.writes(T_repeat[v_ax0, v_ax1, v_ax2])
+                    T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]
+    # fmt: on
+
+    mod = LegalizeOps()(Repeat)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_repeat_no_axis():
+    # fmt: off
+    @I.ir_module
+    class Repeat:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), "float32")):
+            gv = R.repeat(x, 2)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((3, 2, 3), dtype="float32")
+        ) -> R.Tensor((36,), dtype="float32"):
+            gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((36,), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def repeat(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"),
+            T_repeat: T.Buffer((T.int64(36),), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            T_reshape = T.alloc_buffer((T.int64(18),))
+            for ax0 in range(T.int64(18)):
+                with T.block("T_reshape"):
+                    v_ax0 = T.axis.spatial(T.int64(18), ax0)
+                    T.reads(
+                        rxplaceholder[
+                            v_ax0 % T.int64(18) // T.int64(6),
+                            v_ax0 % T.int64(6) // T.int64(3),
+                            v_ax0 % T.int64(3),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0])
+                    T_reshape[v_ax0] = rxplaceholder[
+                        v_ax0 % T.int64(18) // T.int64(6),
+                        v_ax0 % T.int64(6) // T.int64(3),
+                        v_ax0 % T.int64(3),
+                    ]
+            for ax0 in range(T.int64(36)):
+                with T.block("T_repeat"):
+                    v_ax0 = T.axis.spatial(T.int64(36), ax0)
+                    T.reads(T_reshape[v_ax0 // T.int64(2)])
+                    T.writes(T_repeat[v_ax0])
+                    T_repeat[v_ax0] = T_reshape[v_ax0 // T.int64(2)]
+    # fmt: on
+
+    mod = LegalizeOps()(Repeat)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_repeat_symbolic():
+    # fmt: off
+    @I.ir_module
+    class Repeat:
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), "float32")):
+            gv = R.repeat(x, 2, 0)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def repeat(var_rxplaceholder: T.handle, var_T_repeat: T.handle):
+            T.func_attr({"tir.noalias": True})
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c))
+            T_repeat = T.match_buffer(var_T_repeat, (T.int64(2) * a, b, c))
+            # with T.block("root"):
+            for ax0, ax1, ax2 in T.grid(a * T.int64(2), b, c):
+                with T.block("T_repeat"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2])
+                    T.writes(T_repeat[v_ax0, v_ax1, v_ax2])
+                    T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2]
+
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("2 * a", "b", "c"), dtype="float32"):
+            a = T.Var("a", "int64")
+            b = T.Var("b", "int64")
+            c = T.Var("c", "int64")
+            gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((2 * a, b, c), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Repeat)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_tile():
+    # fmt: off
+    @I.ir_module
+    class Tile:
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), "float32")):
+            gv = R.tile(x, (2, 1, 2, 3))
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_tile: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(9)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(9)):
+                with T.block("T_tile"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax1 % T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)])
+                    T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 % T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)]
+
+        @R.function
+        def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((2, 3, 4, 9), dtype="float32"):
+            gv = R.call_tir(tile, (x,), out_sinfo=R.Tensor((2, 3, 4, 9), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Tile)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_tile_symbolic():
+    # fmt: off
+    @I.ir_module
+    class Tile:
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), "float32")):
+            gv = R.tile(x, (2, 1, 2, 3))
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tile(var_rxplaceholder: T.handle, var_T_tile: T.handle):
+            T.func_attr({"tir.noalias": True})
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c))
+            T_tile = T.match_buffer(var_T_tile, (T.int64(2), a, b * T.int64(2), c * T.int64(3)))
+            # with T.block("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), a, b * T.int64(2), c * T.int64(3)):
+                with T.block("T_tile"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c])
+                    T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c]
+
+        @R.function
+        def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor((2, "a", "b * 2", "c * 3"), dtype="float32"):
+            a = T.Var("a", "int64")
+            b = T.Var("b", "int64")
+            c = T.Var("c", "int64")
+            gv = R.call_tir(tile, (x,), out_sinfo=R.Tensor((2, a, b * 2, c * 3), dtype="float32"))
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(Tile)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 8fb398f15d..90e89944f7 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -18,7 +18,7 @@
 import pytest
 import tvm
 from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
 
 
@@ -207,6 +207,188 @@ def test_conv2d_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_conv2d_transpose():
+    # fmt: off
+    @I.ir_module
+    class Conv2dTranspose:
+        @R.function
+        def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((128, 16, 3, 3), "float32")):
+            gv = R.nn.conv2d_transpose(x, w, strides=(2, 3), padding=(1, 1), dilation=(1, 1), output_padding=(1, 2), groups=8)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3, 3), dtype="float32")) -> R.Tensor((2, 128, 56, 84), dtype="float32"):
+            gv = R.call_tir(conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 128, 56, 84), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            data_dilate = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(55), T.int64(82)))
+            data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58), T.int64(86)))
+            kernel_transform = T.alloc_buffer((T.int64(16), T.int64(128), T.int64(3), T.int64(3)))
+            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(55), T.int64(82)):
+                with T.block("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)])
+                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 % T.int64(2) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)], T.float32(0))
+            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(58), T.int64(86)):
+                with T.block("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
+                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56) and T.int64(1) <= v_i3 and v_i3 < T.int64(83), data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
+            for i, o, h, w in T.grid(T.int64(16), T.int64(128), T.int64(3), T.int64(3)):
+                with T.block("kernel_transform"):
+                    v_i, v_o, v_h, v_w = T.axis.remap("SSSS", [i, o, h, w])
+                    T.reads(rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w])
+                    T.writes(kernel_transform[v_i, v_o, v_h, v_w])
+                    kernel_transform[v_i, v_o, v_h, v_w] = rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w]
+            for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(128), T.int64(56), T.int64(84), T.int64(16), T.int64(3), T.int64(3)):
+                with T.block("compute"):
+                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
+                    T.reads(data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw])
+                    T.writes(compute[v_b, v_c, v_h, v_w])
+                    with T.init():
+                        compute[v_b, v_c, v_h, v_w] = T.float32(0)
+                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw]
+    # fmt: on
+
+    mod = LegalizeOps()(Conv2dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv2d_transpose_with_out_dtype():
+    # fmt: off
+    @tvm.script.ir_module
+    class Conv2dTranspose:
+        @R.function
+        def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3, 3), "float32")):
+            gv = R.nn.conv2d_transpose(x, w, out_dtype="float16")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 30, 30), dtype="float16"):
+            gv = R.call_tir(conv2d_transpose, (x, w), out_sinfo=R.Tensor((2, 4, 30, 30), dtype="float16"))
+            return gv
+
+        @T.prim_func
+        def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(30), T.int64(30)), "float16")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            data_dilate = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)))
+            data_pad = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(32), T.int64(32)))
+            kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)))
+            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)):
+                with T.block("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3])
+                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3] = rxplaceholder[v_i0, v_i1, v_i2, v_i3]
+            for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(32)):
+                with T.block("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)])
+                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(30) and T.int64(2) <= v_i3 and v_i3 < T.int64(30), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 - T.int64(2)], T.float32(0))
+            for o, i, h, w in T.grid(T.int64(4), T.int64(3), T.int64(3), T.int64(3)):
+                with T.block("kernel_transform"):
+                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w])
+                    T.reads(rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w])
+                    T.writes(kernel_transform[v_o, v_i, v_h, v_w])
+                    kernel_transform[v_o, v_i, v_h, v_w] = rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]
+            for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(4), T.int64(30), T.int64(30), T.int64(3), T.int64(3), T.int64(3)):
+                with T.block("compute"):
+                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
+                    T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c, v_dc, v_dh, v_dw])
+                    T.writes(compute[v_b, v_c, v_h, v_w])
+                    with T.init():
+                        compute[v_b, v_c, v_h, v_w] = T.float16(0)
+                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] + T.Cast("float16", data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw]) * T.Cast("float16", kernel_transform[v_c, v_dc, v_dh, v_dw])
+    # fmt: on
+
+    mod = LegalizeOps()(Conv2dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv2d_transpose_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Conv2dTranspose:
+        @R.function
+        def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), "float32")):
+            gv = R.nn.conv2d_transpose(x, kernel, strides=(3, 3))
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel: R.Tensor(("f", "c", "kh", "kw"), dtype="float32")) -> R.Tensor(("n", "c", "h * 3 + kh - 3", "w * 3 + kw - 3"), dtype="float32"):
+            n = T.Var("n", "int64")
+            c = T.Var("c", "int64")
+            h = T.Var("h", "int64")
+            kh = T.Var("kh", "int64")
+            w = T.Var("w", "int64")
+            kw = T.Var("kw", "int64")
+            f = T.Var("f", "int64")
+            gv = R.call_tir(conv2d_transpose, (x, kernel), out_sinfo=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            n = T.var("int64")
+            c = T.var("int64")
+            h = T.var("int64")
+            w = T.var("int64")
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, h, w))
+            f = T.var("int64")
+            kh = T.var("int64")
+            kw = T.var("int64")
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kh, kw))
+            compute = T.match_buffer(var_compute, (n, c, h * T.int64(3) + kh - T.int64(3), w * T.int64(3) + kw - T.int64(3)))
+            # with T.block("root"):
+            data_dilate = T.alloc_buffer((n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)))
+            data_pad = T.alloc_buffer((n, c, h * T.int64(3) + kh * T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)))
+            kernel_transform = T.alloc_buffer((c, c, kh, kw))
+            for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) - T.int64(2), w * T.int64(3) - T.int64(2)):
+                with T.block("data_dilate"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)])
+                    T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+                    data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2 
% T.int64(3) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0), 
rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)], T.float32(0))
+            for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) + kh * 
T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)):
+                with T.block("data_pad"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, 
v_i3 + T.int64(1) - kw])
+                    T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+                    data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh - 
T.int64(1) <= v_i2 and v_i2 < h * T.int64(3) + kh - T.int64(3) and kw - 
T.int64(1) <= v_i3 and v_i3 < w * T.int64(3) + kw - T.int64(3), 
data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw], 
T.float32(0))
+            for o, i, h_1, w_1 in T.grid(c, c, kh, kw):
+                with T.block("kernel_transform"):
+                    v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1])
+                    T.reads(rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), 
kw - v_w - T.int64(1)])
+                    T.writes(kernel_transform[v_o, v_i, v_h, v_w])
+                    kernel_transform[v_o, v_i, v_h, v_w] = 
rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)]
+            for b, c_1, h_1, w_1, dc, dh, dw in T.grid(n, c, h * T.int64(3) + 
kh - T.int64(3), w * T.int64(3) + kw - T.int64(3), c, kh, kw):
+                with T.block("compute"):
+                    v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw = 
T.axis.remap("SSSSRRR", [b, c_1, h_1, w_1, dc, dh, dw])
+                    T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw], 
kernel_transform[v_c, v_dc, v_dh, v_dw])
+                    T.writes(compute[v_b, v_c, v_h, v_w])
+                    with T.init():
+                        compute[v_b, v_c, v_h, v_w] = T.float32(0)
+                    compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w] 
+ data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc, 
v_dh, v_dw]
+    # fmt: on
+
+    mod = LegalizeOps()(Conv2dTranspose)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_max_pool2d():
     # fmt: off
     @tvm.script.ir_module
@@ -345,6 +527,169 @@ def test_max_pool2d_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_avg_pool2d():
+    # fmt: off
+    @tvm.script.ir_module
+    class AvgPool2D:
+        @R.function
+        def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 
56, 6), "float32"):
+            gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.avg_pool2d(x, 
pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1], 
layout="NHWC")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), 
T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4), 
T.int64(56), T.int64(56), T.int64(6)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            pad_temp = T.alloc_buffer((T.int64(4), T.int64(114), T.int64(114), 
T.int64(6)))
+            pool_sum = T.alloc_buffer((T.int64(4), T.int64(56), T.int64(56), 
T.int64(6)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(114), 
T.int64(114), T.int64(6)):
+                with T.block("pad_temp"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 - 
T.int64(1), v_ax3])
+                    T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
+                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.if_then_else(T.int64(1) <= v_ax1 and v_ax1 < T.int64(113) and T.int64(1) <= 
v_ax2 and v_ax2 < T.int64(113), rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 
- T.int64(1), v_ax3], T.float32(0))
+            for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), 
T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)):
+                with T.block("pool_sum"):
+                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = 
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
+                    T.reads(pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 
* T.int64(2) + v_rv1, v_ax3])
+                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                    with T.init():
+                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 * 
T.int64(2) + v_rv1, v_ax3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(56), 
T.int64(56), T.int64(6)):
+                with T.block("pool_avg"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - 
v_ax1 * T.int64(2)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax2 * 
T.int64(2)) + T.int64(2)))
+
+        @R.function
+        def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) -> 
R.Tensor((4, 56, 56, 6), dtype="float32"):
+            gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 56, 56, 
6), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(AvgPool2D)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_avg_pool2d_NCHW16c():
+    # fmt: off
+    @tvm.script.ir_module
+    class AvgPool2D:
+        @R.function
+        def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 
4, 110, 110, 16), "float32"):
+            gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.avg_pool2d(x, 
pool_size=[3, 3], layout="NCHW16c")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), 
T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg: 
T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), 
"float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            pool_sum = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(110), 
T.int64(110), T.int64(16)))
+            for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4), 
T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)):
+                with T.block("pool_sum"):
+                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 = 
T.axis.remap("SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 + 
v_rv1, v_ax4])
+                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+                    with T.init():
+                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
T.float32(0)
+                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + rxplaceholder[v_ax0, v_ax1, v_ax2 
+ v_rv0, v_ax3 + v_rv1, v_ax4]
+            for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(4), T.int64(4), 
T.int64(110), T.int64(110), T.int64(16)):
+                with T.block("pool_avg"):
+                    v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", 
[ax0, ax1, ax2, ax3, ax4])
+                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", 
(T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1)) * (T.min(T.int64(2), 
T.int64(111) - v_ax3) + T.int64(1)))
+
+        @R.function
+        def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> 
R.Tensor((4, 4, 110, 110, 16), dtype="float32"):
+            gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 4, 110, 
110, 16), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(AvgPool2D)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_avg_pool2d_ceil_mode():
+    # fmt: off
+    @tvm.script.ir_module
+    class AvgPool2D:
+        @R.function
+        def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 
38, 38), "float32"):
+            gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.avg_pool2d(x, 
pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1], 
ceil_mode=True)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), 
T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4), 
T.int64(6), T.int64(38), T.int64(38)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            pad_temp = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(116), 
T.int64(116)))
+            pool_sum = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(38), 
T.int64(38)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), 
T.int64(116), T.int64(116)):
+                with T.block("pad_temp"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), 
v_ax3 - T.int64(1)])
+                    T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
+                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = 
T.if_then_else(T.int64(1) <= v_ax2 and v_ax2 < T.int64(113) and T.int64(1) <= 
v_ax3 and v_ax3 < T.int64(113), rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1), 
v_ax3 - T.int64(1)], T.float32(0))
+            for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(6), 
T.int64(38), T.int64(38), T.int64(3), T.int64(3)):
+                with T.block("pool_sum"):
+                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = 
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
+                    T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, 
v_ax3 * T.int64(3) + v_rv1])
+                    T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                    with T.init():
+                        pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+                    pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3 
* T.int64(3) + v_rv1]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6), 
T.int64(38), T.int64(38)):
+                with T.block("pool_avg"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - 
v_ax2 * T.int64(3)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax3 * 
T.int64(3)) + T.int64(2)))
+
+        @R.function
+        def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> 
R.Tensor((4, 6, 38, 38), dtype="float32"):
+            gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 6, 38, 
38), dtype="float32"))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(AvgPool2D)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI pooling casts every shape value to i32.")
+def test_avg_pool2d_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class AvgPool2D:
+        @R.function
+        def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c", 
"h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"), 
"float32"):
+            n = T.int64()
+            c = T.int64()
+            h = T.int64()
+            w = T.int64()
+            kh = T.int64()
+            kw = T.int64()
+            gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") = 
R.nn.avg_pool2d(x, pool_size=[kh, kw])
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(AvgPool2D)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_adaptive_avg_pool2d():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py 
b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index c1d0c90d34..10b786ee4a 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -343,5 +343,50 @@ def test_collapse_sum_to():
     _check(foo, bb.get()["foo"])
 
 
+def test_repeat():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4), "float32")):
+        gv = R.repeat(x, 3, 1)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.repeat(x, 3, 1))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_repeat_no_axis():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4), "float32")):
+        gv = R.repeat(x, 3)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.repeat(x, 3))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_tile():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4), "float32")):
+        gv = R.tile(x, (2, 3))
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.tile(x, (2, 3)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py 
b/tests/python/relax/test_tvmscript_parser_op_nn.py
index c2bfa5b7a9..cfb454a578 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -53,6 +53,26 @@ def test_conv2d():
     _check(foo, bb.get()["foo"])
 
 
+def test_conv2d_transpose():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((3, 16, 5, 5), 
"float32")
+    ) -> R.Tensor((2, 16, 232, 232), "float16"):
+        gv: R.Tensor((2, 16, 232, 232), "float16") = R.nn.conv2d_transpose(
+            x, w, out_dtype="float16"
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32"))
+    w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, w]):
+        gv = bb.emit(relax.op.nn.conv2d_transpose(x, w, out_dtype="float16"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_max_pool2d():
     @R.function
     def foo(
@@ -70,6 +90,23 @@ def test_max_pool2d():
     _check(foo, bb.get()["foo"])
 
 
+def test_avg_pool2d():
+    @R.function
+    def foo(
+        x: R.Tensor((1, 1, 32, 32), dtype="float32")
+    ) -> R.Tensor((1, 1, 30, 30), dtype="float32"):
+        gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.avg_pool2d(x, 
pool_size=(3,))
+        return gv
+
+    x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.avg_pool2d(x, pool_size=(3,)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_adaptive_avg_pool2d():
     @R.function
     def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7), 
"float32"):

Reply via email to