This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0fc5fd5118 [Unity][Op] Add repeat, tile, conv2d_transpose, avg_pool2d
(#14238)
0fc5fd5118 is described below
commit 0fc5fd5118cb22db9b747d8cf3bf6dd17ed52c3b
Author: Yixin Dong <[email protected]>
AuthorDate: Thu Mar 9 21:21:38 2023 +0800
[Unity][Op] Add repeat, tile, conv2d_transpose, avg_pool2d (#14238)
This PR adds a series of operators for relax:
- `repeat(data: Expr, repeats: int, axis: Optional[int] = None)`
- `tile(data: Expr, repeats: Union[int, Tuple[int], List[int]])`
- `relax.nn.conv2d_transpose`
- This operator is intended to find the gradient of conv2d w.r.t
its input. For details see [pytorch
document](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html),
and the document in `python/tvm/relax/op/nn/nn.py`.
- Now TOPI support of conv2d_transpose is limited. It does not
support dilations; layouts other than default; symbolic `output_padding`.
- `relax.nn.avg_pool2d`
---
include/tvm/relax/attrs/manipulate.h | 23 ++
include/tvm/relax/attrs/nn.h | 55 ++-
python/tvm/relax/op/manipulate.py | 74 ++++
python/tvm/relax/op/nn/nn.py | 182 +++++++++-
python/tvm/relax/op/op_attrs.py | 19 +-
python/tvm/relax/transform/legalize_ops/common.py | 2 +-
python/tvm/relax/transform/legalize_ops/index.py | 6 +-
.../tvm/relax/transform/legalize_ops/manipulate.py | 25 ++
python/tvm/relax/transform/legalize_ops/nn.py | 58 ++++
python/tvm/script/ir_builder/relax/ir.py | 4 +
src/relax/op/nn/convolution.cc | 144 ++++++++
src/relax/op/nn/convolution.h | 11 +
src/relax/op/nn/pooling.cc | 44 ++-
src/relax/op/nn/pooling.h | 4 +
src/relax/op/tensor/manipulate.cc | 126 +++++++
src/relax/op/tensor/manipulate.h | 27 ++
tests/python/relax/test_op_manipulate.py | 246 +++++++++++++
tests/python/relax/test_op_nn_convolution.py | 385 +++++++++++++++++++++
tests/python/relax/test_op_nn_pooling.py | 228 +++++++++++-
.../test_transform_legalize_ops_manipulate.py | 199 ++++++++++-
.../python/relax/test_transform_legalize_ops_nn.py | 347 ++++++++++++++++++-
.../relax/test_tvmscript_parser_op_manipulate.py | 45 +++
tests/python/relax/test_tvmscript_parser_op_nn.py | 37 ++
23 files changed, 2264 insertions(+), 27 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index bd6ae17bcf..4982daf7e4 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -102,6 +102,29 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}
}; // struct SqueezeAttrs
+/*! \brief Attributes used in repeat operators */
+struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
+ int repeats;
+ Optional<Integer> axis;
+
+ TVM_DECLARE_ATTRS(RepeatAttrs, "relax.attrs.RepeatAttrs") {
+ TVM_ATTR_FIELD(repeats).describe("The number of repetitions.");
+ TVM_ATTR_FIELD(axis).describe(
+ "The axis along which to repeat values. The negative numbers are
interpreted "
+ "counting from the backward. By default, use the flattened input
array, and "
+ "return a flat output array.");
+ }
+}; // struct RepeatAttrs
+
+/*! \brief Attributes used in tile operators */
+struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
+ Array<Integer> repeats;
+
+ TVM_DECLARE_ATTRS(TileAttrs, "relax.attrs.TileAttrs") {
+ TVM_ATTR_FIELD(repeats).describe("The number of repetitions of data along
each axis.");
+ }
+}; // struct TileAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 61b1622a60..f49cb6b121 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -74,8 +74,55 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
}
}; // struct Conv2dAttrs
-/*! \brief Attributes used in max_pool2d operator */
-struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
+/*! \brief Attributes used in conv2d_transpose operator */
+struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> output_padding;
+ Array<IntImm> dilation;
+ int groups;
+ String data_layout;
+ String kernel_layout;
+ String out_layout;
+ DataType out_dtype;
+
+ TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relax.attrs.Conv2DTransposeAttrs") {
+ TVM_ATTR_FIELD(strides).describe("Specifies the strides of the
convolution.");
+ TVM_ATTR_FIELD(padding).describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : bottom, right will use same padding as top, left"
+ "four int : padding width in the order of (top, left, bottom, right)");
+ TVM_ATTR_FIELD(output_padding).describe("Used to disambiguate the output
shape.");
+ TVM_ATTR_FIELD(dilation).describe(
+ "Specifies the dilation rate to use for dilated convolution.");
+ TVM_ATTR_FIELD(groups).describe(
+ "Number of groups to split the input into for grouped convolution. The
number of input and "
+ "output channels should be divisible by the number of groups.");
+ TVM_ATTR_FIELD(data_layout)
+ .describe(
+ "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Convolution is applied on the 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(kernel_layout)
+ .describe(
+ "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+ "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height,
and width"
+ "dimensions respectively.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+ "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+ "dimensions respectively. Default to be same as input layout.");
+ TVM_ATTR_FIELD(out_dtype).describe(
+ "Output data type, set to explicit type under mixed precision
setting");
+ }
+}; // struct Conv2DTransposeAttrs
+
+/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
+struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
Array<IntImm> pool_size;
Array<IntImm> strides;
Array<IntImm> padding;
@@ -84,7 +131,7 @@ struct MaxPool2DAttrs : public
tvm::AttrsNode<MaxPool2DAttrs> {
String layout;
String out_layout;
- TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") {
+ TVM_DECLARE_ATTRS(Pool2DAttrs, "relax.attrs.Pool2DAttrs") {
TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
TVM_ATTR_FIELD(strides).describe("Specifies the strides of the
convolution.");
TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the
convolution.");
@@ -109,7 +156,7 @@ struct MaxPool2DAttrs : public
tvm::AttrsNode<MaxPool2DAttrs> {
"dimensions respectively. Pooling is applied on the 'H' and"
"'W' dimensions.");
}
-}; // struct MaxPool2dAttrs
+}; // struct Pool2dAttrs
/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index 25bf525191..c59ab793e0 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -314,3 +314,77 @@ def collapse_sum_to(data: Expr, shape:
Union[Tuple[PrimExprLike], Expr]) -> Expr
if isinstance(shape, (tuple, list)):
shape = ShapeExpr(shape)
return _ffi_api.collapse_sum_to(data, shape) # type: ignore
+
+
+def repeat(data: Expr, repeats: int, axis: Optional[int] = None) -> Expr:
+ """Repeats elements of an array.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input tensor.
+
+ repeats : int
+ The number of repetitions.
+
+ axis: Optional[int]
+ The axis along which to repeat values. The negative numbers are
interpreted
+ counting from the backward. By default, use the flattened input array,
and
+ return a flat output array.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+ x = R.const([[1, 2], [3, 4]])
+ lv1 = R.repeat(x, repeats=2) # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
+ lv2 = R.repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.],
+ # [3., 3., 4., 4.]]
+ """
+ return _ffi_api.repeat(data, repeats, axis) # type: ignore
+
+
+def tile(data: Expr, repeats: Union[int, Tuple[int], List[int]]) -> Expr:
+ """Construct an array by repeating data the number of times given by
repeats.
+
+ If repeats has length l, and data has dimension d, the result will have
dimension of max(l, d).
+
+ If d < l, data is promoted to be l-dimensional by prepending new axes. So
a shape (3,) Tensor is
+ promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D
replication. If this is not
+ the desired behavior, promote data to d-dimensions manually before calling
this function.
+
+ If d > l, reps is promoted to length d by pre-pending 1's to it. Thus for
a data of shape
+ (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ repeats : Union[int, Tuple[int], List[int]]
+ The number of repetitions of data along each axis.
+
+ Returns
+ -------
+ ret : relax.Expr
+ The computed result.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ x = R.const([[1, 2], [3, 4]])
+        lv1 = R.tile(x, repeats=(2, 3)) # lv1 = [[1., 2., 1., 2., 1., 2.],
+                                        # [3., 4., 3., 4., 3., 4.],
+                                        # [1., 2., 1., 2., 1., 2.],
+                                        # [3., 4., 3., 4., 3., 4.]]
+        lv2 = R.tile(x, repeats=2) # lv2 = [[1., 2., 1., 2.],
+                                   # [3., 4., 3., 4.]]
+ """
+ if isinstance(repeats, int):
+ repeats = [repeats]
+ return _ffi_api.tile(data, repeats) # type: ignore
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index bbb1268f1c..c774bbc926 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -121,6 +121,108 @@ def conv2d(
)
+def conv2d_transpose(
+ data: Expr,
+ weight: Expr,
+ strides: Union[int, Tuple[int, int]] = (1, 1),
+ padding: Union[int, Tuple[int, ...]] = (0, 0),
+ output_padding: Union[int, Tuple[int, int]] = (0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
+ groups: int = 1,
+ data_layout: str = "NCHW",
+ kernel_layout: str = "IOHW",
+ out_layout: Optional[str] = None,
+ out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+ r"""Two dimensional transposed convolution operator.
+
+ This operator is intended to be the gradient operator of conv2d. That
means, if
+
+ `out = conv2d(data, weight, strides, padding, dilation)`,
+
+ The gradient w.r.t. data can be calculated as follows:
+
+ `data_grad = conv2d_transpose(out_grad, weight, strides, padding,
output_padding, dilation)`,
+
+ where `output_padding` is a parameter used to determine the output shape.
+
+ The output shape can be explained in the simple case when `data_layout ==
"NCHW"` and
+ `kernel_layout == "IOHW"`. Suppose `data` has shape `(N, in_channel, in_h,
in_w)`, `weight` has
+ shape `(in_channel, out_channel, weight_h, weight_w)`, we need to assure
that
+ `in_channel % groups == 0`. The shape of the output will be
+ `(N, out_channel * groups, out_h, out_w)`, where
+
+ - `out_h = ((in_h - 1) * strides[0] + weight_h - 2 * padding[0] +
output_padding[0])`
+ - `out_w = ((in_w - 1) * strides[1] + weight_w - 2 * padding[1] +
output_padding[1])`
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ weight : relax.Expr
+ The weight expressions.
+
+ strides : Union[int, Tuple[int, int]]
+ The strides of convolution. It is required to have length either 1 or
2.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding of convolution on both sides of inputs before convolution.
+ It is required to have length either 1, 2 or 4.
+
+ output_padding : Union[int, Tuple[int, ...]], optional
+ Used to disambiguate the output shape.
+
+ dilation : Union[int, Tuple[int, int]]
+ Specifies the dilation rate to be used for dilated convolution.
+ It is required to have length either 1 or 2.
+
+ groups : int
+ Number of groups to split the input into for grouped convolution.
+ The number of input and output channels should be divisible by the
number of groups.
+
+ data_layout : str
+ Layout of the input.
+
+ kernel_layout : str
+ Layout of the weight.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ out_dtype : Optional[Union[str, DataType]]
+ Specifies the output data type for mixed precision conv2d.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ # TODO: symbolic shape is not fully supported now
+ if isinstance(strides, int):
+ strides = (strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding)
+ if isinstance(output_padding, int):
+ output_padding = (output_padding, output_padding)
+
+ return _ffi_api.conv2d_transpose( # type: ignore
+ data,
+ weight,
+ strides,
+ padding,
+ output_padding,
+ dilation,
+ groups,
+ data_layout,
+ kernel_layout,
+ out_layout,
+ out_dtype,
+ )
+
+
def max_pool2d(
data: Expr,
pool_size: Union[int, Tuple[int, int]] = (1, 1),
@@ -134,8 +236,7 @@ def max_pool2d(
r"""2D maximum pooling operator.
This operator takes data as input and does 2D max value calculation
- with in pool_size sized window by striding defined by stride
-
+    within pool_size sized window by striding defined by stride.
In the default case, where the data_layout is `NCHW`
a data Tensor with shape `(batch_size, in_channels, height, width)`,
@@ -198,6 +299,83 @@ def max_pool2d(
)
+def avg_pool2d(
+ data: Expr,
+ pool_size: Union[int, Tuple[int, int]] = (1, 1),
+ strides: Union[int, Tuple[int, int]] = (1, 1),
+ padding: Union[int, Tuple[int, ...]] = (0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
+ ceil_mode: bool = False,
+ layout: str = "NCHW",
+ out_layout: Optional[str] = None,
+) -> Expr:
+ r"""2D average pooling operator.
+
+    This operator takes data as input and does 2D average value calculation
+    within pool_size sized window by striding defined by stride.
+
+ In the default case, where the data_layout is `NCHW`
+ a data Tensor with shape `(batch_size, in_channels, height, width)`,
+ to produce an output Tensor with the following rule:
+
+ with data of shape (b, c, h, w) and pool_size (kh, kw)
+
+ .. math::
+
+ \mbox{out}(b, c, y, x) = \frac{1}{kh * kw} \sum_{m=0, \ldots, kh-1}
+ \sum_{n=0, \ldots, kw-1}
+ \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x +
n)
+
+ Padding is applied to data before the computation.
+ ceil_mode is used to take ceil or floor while computing out shape.
+ This operator accepts data layout specification.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ pool_size : Union[int, Tuple[int, int]]
+ The size of window for pooling. It is required to have length either 1
or 2.
+
+ strides : Union[int, Tuple[int, int]]
+ The strides of pooling. It is required to have length either 1 or 2.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding for pooling. It is required to have length either 1, 2 or
4.
+
+ dilation : Union[int, Tuple[int, int]]
+ The dilation of pooling. It is required to have length either 1 or 2.
+
+ ceil_mode : bool
+ A boolean indicating if use ceil or floor to compute the output shape.
+ By using ceil, every element in the input tensor will be covered by a
sliding window.
+
+ layout : str
+ Layout of the input.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ Returns
+ -------
+ result : Expr
+ The computed result.
+ """
+ if isinstance(pool_size, int):
+ pool_size = (pool_size, pool_size)
+ if isinstance(strides, int):
+ strides = (strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding)
+
+ return _ffi_api.avg_pool2d( # type: ignore
+ data, pool_size, strides, padding, dilation, ceil_mode, layout,
out_layout
+ )
+
+
def adaptive_avg_pool2d(
data: Expr,
output_size: Optional[Union[int, Tuple[int, int]]] = None,
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 05531567bb..2d0fdd14b3 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -54,8 +54,13 @@ class Conv2DAttrs(Attrs):
"""Attributes for nn.conv2d"""
-@tvm._ffi.register_object("relax.attrs.MaxPool2DAttrs")
-class MaxPool2DAttrs(Attrs):
+@tvm._ffi.register_object("relax.attrs.Conv2DTransposeAttrs")
+class Conv2DTransposeAttrs(Attrs):
+ """Attributes for nn.conv2d_transpose"""
+
+
+@tvm._ffi.register_object("relax.attrs.Pool2DAttrs")
+class Pool2DAttrs(Attrs):
"""Attributes for nn.max_pool2d"""
@@ -127,3 +132,13 @@ class Resize2DAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.ArgmaxArgminAttrs")
class ArgmaxArgminAttrs(Attrs):
"""Attributes for argmax/argmin operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.RepeatAttrs")
+class RepeatAttrs(Attrs):
+ """Attributes for repeat operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.TileAttrs")
+class TileAttrs(Attrs):
+ """Attributes for tile operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/common.py
b/python/tvm/relax/transform/legalize_ops/common.py
index 4407b3fdf3..4ee9c6758f 100644
--- a/python/tvm/relax/transform/legalize_ops/common.py
+++ b/python/tvm/relax/transform/legalize_ops/common.py
@@ -57,7 +57,7 @@ def _try_convert_to_scalar_const(
The expr to be checked and converted.
Returns
- --–----
+ -------
ret : Union[Expr, FloatImm, IntImm, bool, float, int]
Return a FloatImm or IntImm if the given expr is a scalar integer or
float constant, and the
python native flag is False. Or return the plain value of the constant
in native python type
diff --git a/python/tvm/relax/transform/legalize_ops/index.py
b/python/tvm/relax/transform/legalize_ops/index.py
index 9ee6b28130..eccccc7c6d 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -36,10 +36,8 @@ def _take(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.strided_slice")
def _strided_slice(bb: BlockBuilder, call: Call) -> Expr:
if not all(
- [
- isinstance(call.args[0].struct_info.shape.values[i.value],
tir.IntImm)
- for i in call.attrs.axes
- ]
+ isinstance(call.args[0].struct_info.shape.values[i.value], tir.IntImm)
+ for i in call.attrs.axes
):
logging.info(
"Cases where an axis with symbolic length is sliced are not able "
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 5b992eff1d..7c67d5b26c 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -17,9 +17,11 @@
# pylint: disable=invalid-name
"""Default legalization function for manipulate operators."""
import logging
+from typing import Optional
import tvm
from tvm import topi, tir, relax, te
+from tvm.tir.expr import IntImm
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, Var, Tuple, TupleGetItem
from .common import TEFunc, LegalizeFunc, register_legalize
@@ -117,3 +119,26 @@ def _split(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.squeeze")
def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
+
+
+@register_legalize("relax.repeat")
+def _repeat(bb: BlockBuilder, call: Call) -> Expr:
+ def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
+ if axis is None:
+ # flatten data
+ out_shape = data.shape[0]
+ for i in data.shape[1:]:
+ out_shape *= i
+ data = topi.reshape(data, (out_shape,))
+ axis = 0
+ # topi only receives int repeats and axis
+ return topi.repeat(data, int(repeats), int(axis))
+
+ return bb.call_te(
+ te_repeat, call.args[0], call.attrs.repeats, call.attrs.axis,
primfunc_name_hint="repeat"
+ )
+
+
+@register_legalize("relax.tile")
+def _tile(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(topi.tile, call.args[0], call.attrs.repeats)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index a61e0cd09e..bfc0544536 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -65,6 +65,41 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.conv2d_transpose")
+def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.data_layout:
+ logging.info(
+ "TOPI conv2d_transpose does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+ if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW":
+ logging.info(
+ "TOPI conv2d_transpose does not support input layout other than
NCHW, "
+ "and kernel layout other than IOHW, so cannot be legalized by TOPI"
+ )
+ return call
+ dilation = call.attrs.dilation
+ if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1:
+ logging.info(
+ "TOPI conv2d_transpose does not support dilations other than 1, "
+ "and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.group_conv2d_transpose_nchw,
+ call.args[0],
+ call.args[1],
+ stride=call.attrs.strides,
+ padding=call.attrs.padding,
+ out_dtype=call.struct_info.dtype,
+ output_padding=call.attrs.output_padding,
+ groups=call.attrs.groups,
+ primfunc_name_hint="conv2d_transpose",
+ )
+
+
@register_legalize("relax.nn.max_pool2d")
def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
@@ -88,6 +123,29 @@ def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.avg_pool2d")
+def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.layout:
+ logging.info(
+ "TOPI avg_pool2d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.pool2d,
+ call.args[0],
+ kernel=call.attrs.pool_size,
+ stride=call.attrs.strides,
+ dilation=call.attrs.dilation,
+ padding=call.attrs.padding,
+ pool_type="avg",
+ ceil_mode=call.attrs.ceil_mode,
+ layout=call.attrs.layout,
+ primfunc_name_hint="avg_pool2d",
+ )
+
+
@register_legalize("relax.nn.adaptive_avg_pool2d")
def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 9ef403181b..6166ae6330 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -93,6 +93,7 @@ from tvm.relax.op import (
power,
print,
prod,
+ repeat,
reshape,
round,
shape_of,
@@ -112,6 +113,7 @@ from tvm.relax.op import (
subtract,
tan,
tanh,
+ tile,
tril,
triu,
unique,
@@ -599,6 +601,7 @@ __all__ = [
"prim_value",
"print",
"prod",
+ "repeat",
"reshape",
"round",
"shape",
@@ -622,6 +625,7 @@ __all__ = [
"take",
"tan",
"tanh",
+ "tile",
"tril",
"triu",
"tuple",
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index a3ddd3e350..8dc3c9696f 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -142,5 +142,149 @@ TVM_REGISTER_OP("relax.nn.conv2d")
.set_attrs_type<Conv2DAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d);
+/* relax.nn.conv2d_transpose */
+TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
+
+Expr conv2d_transpose(Expr data, Expr weight, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> output_padding, Array<IntImm> dilation,
int groups,
+ String data_layout, String kernel_layout,
Optional<String> out_layout,
+ DataType out_dtype) {
+ padding = GetCompletePadding2D(std::move(padding));
+ if (output_padding.size() == 1) {
+ output_padding.push_back(output_padding[0]);
+ }
+ if (strides.size() == 1) {
+ strides.push_back(strides[0]);
+ }
+ if (dilation.size() == 1) {
+ dilation.push_back(dilation[0]);
+ }
+
+ CHECK_GT(groups, 0) << "The number of groups in convolution is expected to
be positive. However, "
+ "the given number of groups is "
+ << groups;
+ CHECK_EQ(output_padding.size(), 2) << "The input output_padding length is
expected to be 2. "
+ "However, the given output_padding is "
+ << output_padding;
+ CHECK_EQ(strides.size(), 2)
+ << "The input strides length is expected to be 2. However, the given
strides is " << strides;
+ CHECK_EQ(dilation.size(), 2)
+ << "The input dilation length is expected to be 2. However, the given
dilation is "
+ << dilation;
+
+ auto attrs = make_object<Conv2DTransposeAttrs>();
+ attrs->strides = ConvertIntImmToInt64(strides);
+ attrs->padding = ConvertIntImmToInt64(padding);
+ attrs->output_padding = ConvertIntImmToInt64(output_padding);
+ attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->groups = groups;
+ attrs->data_layout = data_layout;
+ attrs->kernel_layout = std::move(kernel_layout);
+ attrs->out_layout = std::move(out_layout.value_or(data_layout));
+ attrs->out_dtype = std::move(out_dtype);
+ const Op& op = Op::Get("relax.nn.conv2d_transpose");
+ return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv2d_transpose").set_body_typed(conv2d_transpose);
+
+StructInfo InferStructInfoConv2dTranspose(const Call& call, const
BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ TensorStructInfo data_sinfo = input_sinfo[0];
+ TensorStructInfo weight_sinfo = input_sinfo[1];
+
+ const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
+ auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx,
attrs->data_layout, //
+ /*tgt_layout=*/"NCHW",
//
+ /*tensor_name=*/"data");
+ auto [weight_layout, weight2IOHW] = CheckTensorLayout(call, ctx,
attrs->kernel_layout, //
+ /*tgt_layout=*/"IOHW",
//
+
/*tensor_name=*/"kernel");
+ auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx,
attrs->out_layout, //
+ /*tgt_layout=*/"NCHW",
//
+ /*tensor_name=*/"output");
+
+ Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ Optional<ShapeExpr> weight_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+ DataType out_dtype = attrs->out_dtype.is_void()
+ ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo,
weight_sinfo)
+ : attrs->out_dtype;
+ if (!data_shape.defined() || !weight_shape.defined()) {
+ return TensorStructInfo(out_dtype, out_layout.ndim());
+ }
+
+ Array<PrimExpr> data_NCHW_shape =
data2NCHW.ForwardShape(data_shape.value()->values);
+ Array<PrimExpr> weight_IOHW_shape =
weight2IOHW.ForwardShape(weight_shape.value()->values);
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ PrimExpr input_channel_data = data_NCHW_shape[1];
+ PrimExpr input_channel_kernel = weight_IOHW_shape[0];
+ if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "Conv2dTranspose expects the channel size of the data should equal
to the input channel "
+ "size of the weight. However, the data channel size is "
+ << input_channel_data << " while the weight input channel size is "
+ << input_channel_kernel);
+ } else if (!analyzer->CanProveEqual(input_channel_data,
input_channel_kernel)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv2dTranspose expects the number of input channels
to be divisible by "
+ "the number of groups. However, the number of input
channels is "
+ << input_channel_kernel << " while the number of groups
is " << attrs->groups);
+ } else if (!analyzer->CanProveEqual(floormod(input_channel_kernel,
attrs->groups), 0)) {
+ // Todo(relax-team): Trust the input shape at this moment, and revisit
+ // this condition with runtime shape check
+ }
+ if (analyzer->CanProve(attrs->output_padding[0]->value >=
attrs->strides[0]->value ||
+ attrs->output_padding[1]->value >=
attrs->strides[1]->value)) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Conv2dTranspose expects the output padding less than
the strides, but the "
+ "output padding is"
+ << attrs->output_padding << " while the strides are" <<
attrs->strides);
+ } else if (!analyzer->CanProve(attrs->output_padding[0]->value <
attrs->strides[0]->value &&
+ attrs->output_padding[1]->value <
attrs->strides[1]->value)) {
+ // Todo(relax-team): Trust the input padding at this moment, and revisit
+ // this condition with runtime shape check
+ }
+
+ PrimExpr input_h = data_NCHW_shape[2];
+ PrimExpr input_w = data_NCHW_shape[3];
+ PrimExpr kernel_h = weight_IOHW_shape[2];
+ PrimExpr kernel_w = weight_IOHW_shape[3];
+ PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
+ PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+
+ std::vector<PrimExpr> out_NCHW_shape;
+ out_NCHW_shape.resize(4);
+ out_NCHW_shape[0] = data_NCHW_shape[0];
+ out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups;
+
+ PrimExpr out_h = (input_h - 1) * attrs->strides[0] - padding_h +
+ attrs->dilation[0] * (kernel_h - 1) +
attrs->output_padding[0] + 1;
+ PrimExpr out_w = (input_w - 1) * attrs->strides[1] - padding_w +
+ attrs->dilation[1] * (kernel_w - 1) +
attrs->output_padding[1] + 1;
+ out_NCHW_shape[2] = analyzer->Simplify(out_h);
+ out_NCHW_shape[3] = analyzer->Simplify(out_w);
+
+ Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for
conv2d_transpose
+// and unit test for mixed_precision
+TVM_REGISTER_OP("relax.nn.conv2d_transpose")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_attrs_type<Conv2DTransposeAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoConv2dTranspose);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
index a65617b48d..7093c6a4d9 100644
--- a/src/relax/op/nn/convolution.h
+++ b/src/relax/op/nn/convolution.h
@@ -57,6 +57,17 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides,
Array<IntImm> padding
Array<IntImm> dilation, int groups, String data_layout, String
kernel_layout,
Optional<String> out_layout, DataType out_dtype);
+/*!
+ * \brief Two dimensional transposed convolution operator.
+ *
+ * This operator is intended to be the backward operator of conv2d. It can be
used to calculate the
+ * gradient of the result of conv2d w.r.t. the input of conv2d.
+ */
+Expr conv2d_transpose(Expr data, Expr weight, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> output_padding, Array<IntImm> dilation,
int groups,
+ String data_layout, String kernel_layout,
Optional<String> out_layout,
+ DataType out_dtype);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index a4c1e6b17d..61001ce678 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -25,11 +25,11 @@
namespace tvm {
namespace relax {
-/* relax.nn.max_pool2d */
-TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+/* relax.nn.max_pool2d and relax.nn.avg_pool2d */
+TVM_REGISTER_NODE_TYPE(Pool2DAttrs);
-Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
- Array<IntImm> dilation, bool ceil_mode, String layout,
+Expr MakePool2d(String op_name, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides,
+ Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
String layout,
Optional<String> out_layout) {
padding = GetCompletePadding2D(std::move(padding));
if (pool_size.size() == 1) {
@@ -51,24 +51,31 @@ Expr max_pool2d(Expr data, Array<IntImm> pool_size,
Array<IntImm> strides, Array
<< "The input dilation length is expected to be 2. However, the given
dilation is "
<< dilation;
- auto attrs = make_object<MaxPool2DAttrs>();
- attrs->pool_size = std::move(pool_size);
+ auto attrs = make_object<Pool2DAttrs>();
+ attrs->pool_size = ConvertIntImmToInt64(pool_size);
attrs->strides = ConvertIntImmToInt64(strides);
attrs->padding = ConvertIntImmToInt64(padding);
attrs->dilation = ConvertIntImmToInt64(dilation);
attrs->ceil_mode = ceil_mode;
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
- static const Op& op = Op::Get("relax.nn.max_pool2d");
+ const Op& op = Op::Get(op_name);
return Call(op, {std::move(data)}, Attrs(attrs), {});
}
+Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, String layout,
+ Optional<String> out_layout) {
+ return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ layout, out_layout);
+}
+
TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d);
-StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx)
{
+StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
- const auto* attrs = call->attrs.as<MaxPool2DAttrs>();
+ const auto* attrs = call->attrs.as<Pool2DAttrs>();
auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,
//
/*tgt_layout=*/"NCHW",
//
/*tensor_name=*/"data");
@@ -113,8 +120,23 @@ StructInfo InferStructInfoMaxPool2D(const Call& call,
const BlockBuilder& ctx) {
TVM_REGISTER_OP("relax.nn.max_pool2d")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
- .set_attrs_type<MaxPool2DAttrs>()
- .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMaxPool2D);
+ .set_attrs_type<Pool2DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
+
+Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, String layout,
+ Optional<String> out_layout) {
+ return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d);
+
+TVM_REGISTER_OP("relax.nn.avg_pool2d")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attrs_type<Pool2DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool2D);
/* relax.nn.adaptive_avg_pool2d */
TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index 3c1792d21f..63d2e76772 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -36,6 +36,10 @@ namespace relax {
Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
Array<IntImm> dilation, bool ceil_mode, String layout,
Optional<String> out_layout);
+/*! \brief 2D average pooling operator. */
+Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, String layout,
Optional<String> out_layout);
+
/*! \brief 2D adaptive average pooling operator. */
Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size,
String layout,
Optional<String> out_layout);
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index e146a604af..c3a673f8fa 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -969,5 +969,131 @@ TVM_REGISTER_OP("relax.collapse_sum_to")
.add_argument("shape", "Shape", "The shape to collapse to.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoCollapseSumTo);
+/* relax.repeat */
+TVM_REGISTER_NODE_TYPE(RepeatAttrs);
+
+Expr repeat(Expr data, int repeats, Optional<Integer> axis) {
+ auto attrs = make_object<RepeatAttrs>();
+ attrs->repeats = std::move(repeats);
+ attrs->axis = std::move(axis);
+
+ static const Op& op = Op::Get("relax.repeat");
+ return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.repeat").set_body_typed(repeat);
+
+StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) {
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<RepeatAttrs>();
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+
+ if (attrs->axis.defined() && !data_sinfo->IsUnknownNdim()) {
+ int axis = attrs->axis.value()->value;
+ int ndim = data_sinfo->ndim;
+ if (axis < -ndim || axis >= ndim) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "Repeat requires the input axis belongs range "
+ "[-data.struct_info.ndim, data.struct_info.ndim - 1]. However,
the input axis is "
+ << axis << ", while ndim is " << ndim);
+ }
+ }
+
+ if (data_shape == nullptr) {
+ if (attrs->axis.defined()) {
+ if (analyzer->CanProveEqual(attrs->repeats, 1)) {
+ // the shape does not change
+ return data_sinfo;
+ } else {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+ }
+ } else {
+ return TensorStructInfo(data_sinfo->dtype, 1);
+ }
+ }
+
+ if (!attrs->axis.defined()) {
+ PrimExpr new_shape =
+ analyzer->Simplify(ComputeShapeProduct(data_shape->values) *
attrs->repeats);
+ return TensorStructInfo(ShapeExpr(Array<PrimExpr>({new_shape})),
data_sinfo->dtype);
+ }
+
+ int axis = NormalizeAxis(call, ctx, data_sinfo->ndim,
attrs->axis.value()->value);
+ auto shape_array = data_shape->values;
+ shape_array.Set(axis, analyzer->Simplify(shape_array[axis] *
attrs->repeats));
+ return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype);
+}
+
+// TODO(relax-team): implement FRelaxInferLayout for repeat
+TVM_REGISTER_OP("relax.repeat")
+ .set_attrs_type<RepeatAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRepeat);
+
+/* relax.tile */
+TVM_REGISTER_NODE_TYPE(TileAttrs);
+
+Expr tile(Expr data, Array<Integer> repeats) {
+ auto attrs = make_object<TileAttrs>();
+ attrs->repeats = std::move(repeats);
+
+ static const Op& op = Op::Get("relax.tile");
+ return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.tile").set_body_typed(tile);
+
+StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ const auto* attrs = call->attrs.as<TileAttrs>();
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ int l = attrs->repeats.size();
+ int ndim = data_sinfo->ndim;
+
+ if (data_shape == nullptr) {
+ if (data_sinfo->IsUnknownNdim()) {
+ return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
+ }
+ if (l > ndim) {
+ return TensorStructInfo(data_sinfo->dtype, l);
+ } else {
+ for (auto i : attrs->repeats) {
+ if (!analyzer->CanProveEqual(i, 1)) {
+ return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim);
+ }
+ }
+ // if control reaches here, the shape should not be changed
+ return data_sinfo;
+ }
+ }
+
+ int out_ndim = std::max(l, ndim);
+ int l_delta = out_ndim - l;
+ int ndim_delta = out_ndim - ndim;
+ Array<PrimExpr> out_shape;
+ for (int i = 0; i < out_ndim; ++i) {
+ if (i < l_delta) {
+ out_shape.push_back(data_shape->values[i - ndim_delta]);
+ } else if (i < ndim_delta) {
+ out_shape.push_back(attrs->repeats[i - l_delta]);
+ } else {
+ out_shape.push_back(
+ analyzer->Simplify(data_shape->values[i - ndim_delta] *
attrs->repeats[i - l_delta]));
+ }
+ }
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
+}
+
+// TODO(relax-team): implement FRelaxInferLayout for tile
+TVM_REGISTER_OP("relax.tile")
+ .set_attrs_type<TileAttrs>()
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTile);
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 95e29a3dce..fb75664a1d 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -133,6 +133,33 @@ Expr collapse_sum_like(Expr data, Expr collapse_target);
*/
Expr collapse_sum_to(Expr data, Expr shape);
+/*!
+ * \brief Repeats elements of an array.
+ * \param data The input tensor.
+ * \param repeats The number of repetitions.
+ * \param axis The axis along which to repeat values. Negative numbers are
interpreted as counting
+ * from the end. By default, use the flattened input array, and return a
flat output array.
+ * \return The computed result.
+ */
+Expr repeat(Expr data, int repeats, Optional<Integer> axis = NullOpt);
+
+/*!
+ * \brief Construct an array by repeating data the number of times given by
reps.
+ *
+ * If reps has length l, and data has dimension d, the result will have
dimension of max(l, d).
+ *
+ * If d < l, data is promoted to be l-dimensional by prepending new axes. So a
shape (3,) Tensor is
+ * promoted to (1, 3) for 2-D replication, or shape (1, 1, 3) for 3-D
replication. If this is not
+ * the desired behavior, promote data to d-dimensions manually before calling
this function.
+ *
+ * If d > l, reps is promoted to length d by pre-pending 1's to it. Thus for a
data of shape
+ * (2, 3, 4, 5), a reps of (2, 2) is treated as (1, 1, 2, 2).
+ * \param data The input tensor.
+ * \param repeats The number of repetitions of data along each axis.
+ * \return The computed result.
+ */
+Expr tile(Expr data, Array<Integer> repeats);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index abb414b472..e1f550cc38 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -21,6 +21,7 @@ from tvm import relax, tir
from tvm import TVMError
from tvm.ir import Op
from tvm.script import relax as R
+from tvm.tir.expr import FloatImm, IntImm
def test_op_correctness():
@@ -32,6 +33,8 @@ def test_op_correctness():
assert relax.op.permute_dims(x).op == Op.get("relax.permute_dims")
assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
+ assert relax.op.tile(x, (2, 2, 2)).op == Op.get("relax.tile")
+ assert relax.op.repeat(x, 2, 0).op == Op.get("relax.repeat")
assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c,
a)).op == Op.get(
"relax.layout_transform"
@@ -2704,5 +2707,248 @@ def
test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var():
)
+def test_repeat_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor())
+
+ _check_inference(
+ bb,
+ relax.op.repeat(x0, 2, axis=0),
+ relax.TensorStructInfo((4, 10, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.repeat(x0, 2, axis=-2),
+ relax.TensorStructInfo((2, 20, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.repeat(x0, 2),
+ relax.TensorStructInfo((160,), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.repeat(x1, 2, axis=0),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+ _check_inference(
+ bb,
+ relax.op.repeat(x1, 2),
+ relax.TensorStructInfo(dtype="float32", ndim=1),
+ )
+ _check_inference(bb, relax.op.repeat(x2, 2, axis=0),
relax.TensorStructInfo(dtype="float32"))
+ _check_inference(bb, relax.op.repeat(x2, 2),
relax.TensorStructInfo(dtype="float32", ndim=1))
+ _check_inference(
+ bb,
+ relax.op.repeat(x3, 2, axis=0),
+ relax.TensorStructInfo((4, 10, 4), dtype=""),
+ )
+ _check_inference(bb, relax.op.repeat(x4, 2, axis=0),
relax.TensorStructInfo(dtype="", ndim=3))
+ _check_inference(bb, relax.op.repeat(x5, 2, axis=0),
relax.TensorStructInfo(dtype=""))
+
+
+def test_repeat_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+ _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((a *
2, b, c), "float32"))
+ _check_inference(
+ bb,
+ relax.op.repeat(x, 2, -1),
+ relax.TensorStructInfo((a, b, c * 2), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.repeat(x, 2),
+ relax.TensorStructInfo((a * b * c * 2,), "float32"),
+ )
+
+
+def test_repeat_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+ _check_inference(bb, relax.op.repeat(x0, 2, 0), relax.TensorStructInfo((4,
3, 4), "float16"))
+ _check_inference(bb, relax.op.repeat(x1, 2, 0), relax.TensorStructInfo((4,
3, 4), "int8"))
+
+
+def test_repeat_infer_struct_info_axis_out_of_range():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x0, 2, 3))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x0, 2, -4))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x1, 2, 3))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x1, 2, -4))
+ # okay
+ bb.normalize(relax.op.repeat(x2, 2, 3))
+ bb.normalize(relax.op.repeat(x2, 2, -4))
+
+
+def test_repeat_return_data_sinfo():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+
+ _check_inference(bb, relax.op.repeat(x0, 1, 0), x0.struct_info)
+ _check_inference(bb, relax.op.repeat(x0, 1, -1), x0.struct_info)
+ _check_inference(bb, relax.op.repeat(x1, 1, 0), x1.struct_info)
+ _check_inference(bb, relax.op.repeat(x2, 1, 0), x2.struct_info)
+
+
+def test_repeat_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5),
"float32")))
+ x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ r1 = tir.Var("r", "float32")
+ r2 = tir.StringImm("abc")
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x0, 2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x1, 2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x2, 1.5))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x2, r1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.repeat(x2, r2))
+
+
+def test_tile_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+ x3 = relax.Var("x", R.Tensor((2, 10, 4)))
+ x4 = relax.Var("x", R.Tensor(ndim=3))
+ x5 = relax.Var("x", R.Tensor())
+
+ _check_inference(
+ bb,
+ relax.op.tile(x0, 2),
+ relax.TensorStructInfo((2, 10, 8), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x0, (3, 2)),
+ relax.TensorStructInfo((2, 30, 8), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x0, (4, 3, 2)),
+ relax.TensorStructInfo((8, 30, 8), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x0, (5, 4, 3, 2)),
+ relax.TensorStructInfo((5, 8, 30, 8), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x1, 2),
+ relax.TensorStructInfo(dtype="float32", ndim=3),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x1, (5, 4, 3, 2)),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+ _check_inference(bb, relax.op.tile(x2, (5, 4, 3, 2)),
relax.TensorStructInfo(dtype="float32"))
+ _check_inference(
+ bb,
+ relax.op.tile(x3, 2),
+ relax.TensorStructInfo((2, 10, 8), dtype=""),
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x3, (5, 4, 3, 2)),
+ relax.TensorStructInfo((5, 8, 30, 8), dtype=""),
+ )
+ _check_inference(bb, relax.op.tile(x4, 2),
relax.TensorStructInfo(dtype="", ndim=3))
+ _check_inference(bb, relax.op.tile(x4, (5, 4, 3, 2)),
relax.TensorStructInfo(dtype="", ndim=4))
+ _check_inference(bb, relax.op.tile(x5, (5, 4, 3, 2)),
relax.TensorStructInfo(dtype=""))
+
+
+def test_tile_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ a = tir.Var("a", "int64")
+ b = tir.Var("b", "int64")
+ c = tir.Var("c", "int64")
+ x = relax.Var("x", R.Tensor((a, b, c), "float32"))
+
+ _check_inference(bb, relax.op.tile(x, 2), relax.TensorStructInfo((a, b, c
* 2), "float32"))
+ _check_inference(
+ bb, relax.op.tile(x, (3, 2)), relax.TensorStructInfo((a, b * 3, c *
2), "float32")
+ )
+ _check_inference(
+ bb, relax.op.tile(x, (4, 3, 2)), relax.TensorStructInfo((a * 4, b * 3,
c * 2), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.tile(x, (5, 4, 3, 2)),
+ relax.TensorStructInfo((5, a * 4, b * 3, c * 2), "float32"),
+ )
+
+
+def test_tile_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+
+ _check_inference(bb, relax.op.tile(x0, (3, 2)), relax.TensorStructInfo((2,
9, 8), "float16"))
+ _check_inference(bb, relax.op.tile(x1, (3, 2)), relax.TensorStructInfo((2,
9, 8), "int8"))
+
+
+def test_tile_return_data_sinfo():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 10, 4), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ x2 = relax.Var("x", R.Tensor("float32"))
+
+ _check_inference(bb, relax.op.tile(x0, 1), x0.struct_info)
+ _check_inference(bb, relax.op.tile(x0, (1, 1)), x0.struct_info)
+ _check_inference(bb, relax.op.tile(x0, (1, 1, 1)), x0.struct_info)
+ _check_inference(bb, relax.op.tile(x1, 1), x1.struct_info)
+ _check_inference(bb, relax.op.tile(x2, 1), x2.struct_info)
+
+
+def test_tile_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5),
"float32")))
+ x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ r1 = tir.Var("a", "float32")
+ r2 = tir.StringImm("abc")
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.tile(x0, 2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.tile(x1, 2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.tile(x2, (2, 1.5, 2)))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.tile(x2, (2, r1)))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.tile(x2, r2))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_convolution.py
b/tests/python/relax/test_op_nn_convolution.py
index 6533d43420..334f6977f7 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -27,6 +27,7 @@ def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d")
+ assert relax.op.nn.conv2d_transpose(x, w).op ==
Op.get("relax.nn.conv2d_transpose")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -425,5 +426,389 @@ def test_conv2d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.conv2d(x1, w0))
+def test_conv2d_transpose_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor())
+ x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=4))
+ w3 = relax.Var("w", R.Tensor("float32"))
+ w4 = relax.Var("w", R.Tensor((4, 48, 3, 3, 16), "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2,
4, 30, 30), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float16"),
+ relax.TensorStructInfo((2, 4, 30, 30), "float16"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, padding=1),
+ relax.TensorStructInfo((2, 4, 28, 28), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2]),
+ relax.TensorStructInfo((2, 4, 28, 26), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, padding=[1, 2, 3, 4]),
+ relax.TensorStructInfo((2, 4, 26, 24), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=1),
+ relax.TensorStructInfo((2, 4, 85, 85), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, strides=3, output_padding=[2, 1]),
+ relax.TensorStructInfo((2, 4, 86, 85), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, strides=2),
+ relax.TensorStructInfo((2, 4, 57, 57), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 3)),
+ relax.TensorStructInfo((2, 4, 57, 84), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, dilation=2),
+ relax.TensorStructInfo((2, 4, 32, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, dilation=(2, 1)),
+ relax.TensorStructInfo((2, 4, 32, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x1, w0, data_layout="NHWC"),
+ relax.TensorStructInfo((2, 30, 30, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, out_layout="NHWC"),
+ relax.TensorStructInfo((2, 30, 30, 4), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="OIHW"),
+ relax.TensorStructInfo((2, 4, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(
+ x5, w4, data_layout="NCHW16c", kernel_layout="IOHW16i",
out_layout="NHWC16c"
+ ),
+ relax.TensorStructInfo((2, 30, 30, 3, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x2, w0),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x3, w0),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x0, w2),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x0, w3),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x4, w0),
relax.TensorStructInfo(dtype="", ndim=4)
+ )
+
+
+def test_conv2d_transpose_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ ki = tir.Var("ki", "int64")
+ ko = tir.Var("ko", "int64")
+ kh = tir.Var("kh", "int64")
+ kw = tir.Var("kw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+ w0 = relax.Var("w", R.Tensor((ki, ko, kh, kw), "float32"))
+ w1 = relax.Var("w", R.Tensor((c, ko, kh, kw), "float32"))
+ w2 = relax.Var("w", R.Tensor((c, ko, kh, kw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0),
+ relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w1),
+ relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(
+ x1, w2, data_layout="NCHW16c", kernel_layout="IOHW16i",
out_layout="NCHW"
+ ),
+ relax.TensorStructInfo((n, ko, ih + kh - 1, iw + kw - 1), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(
+ x0, w0, strides=(2, 2), padding=(1, 1), output_padding=(1, 0),
dilation=(2, 2)
+ ),
+ relax.TensorStructInfo(
+ (n, ko, ih * 2 + kh * 2 - 4, iw * 2 + kw * 2 - 5),
+ "float32",
+ ),
+ )
+
+
+def test_conv2d_transpose_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+ s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s3 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+ w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x0, w),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x1, w, data_layout="NCHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=5),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w, out_layout="NCHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=5),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x2, w),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+
+
+def test_conv2d_transpose_infer_struct_info_groups():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 6, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((16, 6, 3, 3, 8), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, groups=8),
+ relax.TensorStructInfo((2, 48, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w1, kernel_layout="IOHW8i", groups=8),
+ relax.TensorStructInfo((2, 48, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x1, w0, data_layout="NCHW16c", groups=8),
+ relax.TensorStructInfo((2, 3, 30, 30, 16), "float32"),
+ )
+
+
+def test_conv2d_transpose_infer_struct_info_symbolic_groups():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((ic, oc, 3, 3), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x, w0, groups=4),
+ relax.TensorStructInfo((n, oc * 4, 30, 30), "float32"),
+ )
+
+
+def test_conv2d_transpose_infer_struct_info_input_channel_group_incompatible():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ ic = tir.Var("c", "int64")
+ oc = tir.Var("oc", "int64")
+ x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+ w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, ic, 28, 28), "float32"))
+ w1 = relax.Var("w", R.Tensor((ic - 1, oc, 3, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, groups=6))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x1, w1, groups=6))
+
+
+def test_conv2d_transpose_non_positive_group():
+ x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((128, 16, 3, 3), "float32"))
+
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, groups=0)
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, groups=-2)
+
+
+def test_conv2d_transpose_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float64"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+ w2 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+ x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32"))
+ w3 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int32"))
+
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x0, w0), relax.TensorStructInfo((2,
4, 30, 30), "float16")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x1, w1), relax.TensorStructInfo((2,
4, 30, 30), "float64")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x2, w2), relax.TensorStructInfo((2,
4, 30, 30), "int8")
+ )
+ _check_inference(
+ bb, relax.op.nn.conv2d_transpose(x3, w3), relax.TensorStructInfo((2,
4, 30, 30), "int32")
+ )
+
+
+def test_conv2d_transpose_unequal_input_channel():
+ bb = relax.BlockBuilder()
+ ic = tir.Var("ic", "int64")
+ x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32"))
+ w0 = relax.Var("w", R.Tensor([4, 3, 3, 3], "float32"))
+ x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32"))
+ w1 = relax.Var("w", R.Tensor([ic + 2, 4, 3, 3], "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x1, w1))
+
+
+def test_conv2d_transpose_wrong_output_padding():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32"))
+ w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=2,
output_padding=2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w0, strides=(2, 2),
output_padding=(2, 2)))
+
+
+def test_conv2d_transpose_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ conv2d_transpose = relax.op.nn.conv2d_transpose(
+ x, w, strides=(1, 1), padding=(1, 1), output_padding=(1, 2),
dilation=(1, 1)
+ )
+
+ assert conv2d_transpose.attrs.strides[0].dtype == "int64"
+ assert conv2d_transpose.attrs.strides[1].dtype == "int64"
+ assert conv2d_transpose.attrs.padding[0].dtype == "int64"
+ assert conv2d_transpose.attrs.padding[1].dtype == "int64"
+ assert conv2d_transpose.attrs.padding[2].dtype == "int64"
+ assert conv2d_transpose.attrs.padding[3].dtype == "int64"
+ assert conv2d_transpose.attrs.output_padding[0].dtype == "int64"
+ assert conv2d_transpose.attrs.output_padding[1].dtype == "int64"
+ assert conv2d_transpose.attrs.dilation[0].dtype == "int64"
+ assert conv2d_transpose.attrs.dilation[1].dtype == "int64"
+
+
+def test_conv2d_transpose_wrong_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, strides=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, output_padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.conv2d_transpose(x, w, dilation=(1, 2, 3))
+
+
+def test_conv2d_transpose_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x, w, data_layout="IOHW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x, w, kernel_layout="NHWC"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x, w, out_layout="OHWI"))
+
+
+def test_conv2d_transpose_dtype_mismatch():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ w = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x, w))
+
+
+def test_conv2d_transpose_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 6, 3, 3), "float32"))
+ w2 = relax.Var("w", R.Tensor("float32", ndim=6))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w1,
data_layout="NCHW16c"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w2))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x2, w0))
+
+
+def test_conv2d_transpose_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+ w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((3, 4, 3, 3),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x0, w1))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_pooling.py
b/tests/python/relax/test_op_nn_pooling.py
index 0eec5de21c..2bd7747f31 100644
--- a/tests/python/relax/test_op_nn_pooling.py
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -26,6 +26,7 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
+ assert relax.op.nn.avg_pool2d(x).op == Op.get("relax.nn.avg_pool2d")
assert relax.op.nn.adaptive_avg_pool2d(x).op ==
Op.get("relax.nn.adaptive_avg_pool2d")
@@ -203,7 +204,7 @@ def test_max_pool2d_infer_struct_info_more_input_dtype():
)
-def test_conv2d_stride_padding_dilation_int64():
+def test_max_pool2d_stride_padding_dilation_int64():
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1,
1), dilation=(1, 1))
@@ -259,6 +260,231 @@ def test_max_pool2d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.max_pool2d(x1))
+def test_avg_pool2d_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x3 = relax.Var("x", R.Tensor("float32"))
+ x4 = relax.Var("x", R.Tensor(ndim=4))
+ x5 = relax.Var("x", R.Tensor())
+ x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32),
"float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, pool_size=3),
+ relax.TensorStructInfo((2, 3, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, pool_size=(5, 3)),
+ relax.TensorStructInfo((2, 3, 28, 30), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x0, padding=1), relax.TensorStructInfo((2,
3, 34, 34), "float32")
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, padding=[1, 2]),
+ relax.TensorStructInfo((2, 3, 34, 36), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, strides=2),
+ relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, dilation=2),
+ relax.TensorStructInfo((2, 3, 32, 32), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x1, layout="NHWC"),
+ relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x0, out_layout="NHWC"),
+ relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"),
+ relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x2),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x3),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(bb, relax.op.nn.avg_pool2d(x4),
relax.TensorStructInfo(dtype="", ndim=4))
+ _check_inference(bb, relax.op.nn.avg_pool2d(x5),
relax.TensorStructInfo(dtype="", ndim=4))
+
+
+def test_avg_pool2d_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ c16 = tir.Var("c16", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+ x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(
+ x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2,
2)
+ ),
+ relax.TensorStructInfo(
+ (
+ n,
+ c,
+ tvm.tir.floordiv(ih - 1, 3) + 1,
+ tvm.tir.floordiv(iw - 1, 3) + 1,
+ ),
+ "float32",
+ ),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x1, layout="NCHW16c", out_layout="NHWC"),
+ relax.TensorStructInfo((n, ih, iw, c * 16), "float32"),
+ )
+
+
+def test_avg_pool2d_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+ s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+ s2 = relax.Var("s", relax.ShapeStructInfo())
+ x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+ x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+ x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x0),
relax.TensorStructInfo(dtype="float32", ndim=4)
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x1, layout="NCHW16c"),
+ relax.TensorStructInfo(dtype="float32", ndim=5),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x2),
+ relax.TensorStructInfo(dtype="float32", ndim=4),
+ )
+
+
+def test_avg_pool2d_infer_struct_info_ceil_mode():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x, pool_size=3, strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True),
+ relax.TensorStructInfo((2, 3, 15, 16), "float32"),
+ )
+
+
+def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic():
+ bb = relax.BlockBuilder()
+ n = tir.Var("n", "int64")
+ c = tir.Var("c", "int64")
+ ih = tir.Var("ih", "int64")
+ iw = tir.Var("iw", "int64")
+ x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.avg_pool2d(
+ x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2,
2), ceil_mode=True
+ ),
+ relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2),
tvm.tir.floordiv(iw, 2)), "float32"),
+ )
+
+
+def test_avg_pool2d_infer_struct_info_more_input_dtype():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8"))
+ x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64"))
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32),
"float16")
+ )
+ _check_inference(bb, relax.op.nn.avg_pool2d(x1),
relax.TensorStructInfo((2, 3, 32, 32), "int8"))
+ _check_inference(
+ bb, relax.op.nn.avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32),
"int64")
+ )
+
+
+def test_avg_pool2d_stride_padding_dilation_int64():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ avg_pool2d = relax.op.nn.avg_pool2d(x, (3, 3), strides=(1, 1), padding=(1,
1), dilation=(1, 1))
+
+ assert avg_pool2d.attrs.strides[0].dtype == "int64"
+ assert avg_pool2d.attrs.strides[1].dtype == "int64"
+ assert avg_pool2d.attrs.padding[0].dtype == "int64"
+ assert avg_pool2d.attrs.padding[1].dtype == "int64"
+ assert avg_pool2d.attrs.padding[2].dtype == "int64"
+ assert avg_pool2d.attrs.padding[3].dtype == "int64"
+ assert avg_pool2d.attrs.dilation[0].dtype == "int64"
+ assert avg_pool2d.attrs.dilation[1].dtype == "int64"
+
+
+def test_avg_pool2d_wrong_pool_size_strides_padding_dilation_length():
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ with pytest.raises(TVMError):
+ relax.op.nn.avg_pool2d(x, pool_size=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.avg_pool2d(x, strides=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.avg_pool2d(x, padding=(1, 2, 3))
+ with pytest.raises(TVMError):
+ relax.op.nn.avg_pool2d(x, dilation=(1, 2, 3))
+
+
+def test_avg_pool2d_infer_struct_info_wrong_layout_string():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x, layout="OIHW"))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x, out_layout="OHWI"))
+
+
+def test_avg_pool2d_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x1))
+
+
+def test_avg_pool2d_infer_struct_info_wrong_input_type():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+ x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28),
"float32")))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.avg_pool2d(x1))
+
+
def test_adaptive_avg_pool2d_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 7ae0eb359a..fbb1024d26 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -18,7 +18,7 @@
import pytest
import tvm
from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
@@ -888,5 +888,202 @@ def test_collapse_sum_to_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_repeat():
+ # fmt: off
+ @I.ir_module
+ class Repeat:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), "float32")):
+ gv = R.repeat(x, 2, 0)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2,
3), dtype="float32"):
+ gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((6, 2, 3),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def repeat(rxplaceholder: T.Buffer((T.int64(3), T.int64(2),
T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(6), T.int64(2),
T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(T.int64(6), T.int64(2), T.int64(3)):
+ with T.block("T_repeat"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2])
+ T.writes(T_repeat[v_ax0, v_ax1, v_ax2])
+ T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 //
T.int64(2), v_ax1, v_ax2]
+ # fmt: on
+
+ mod = LegalizeOps()(Repeat)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_repeat_no_axis():
+ # fmt: off
+ @I.ir_module
+ class Repeat:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), "float32")):
+ gv = R.repeat(x, 2)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((3, 2, 3), dtype="float32")
+ ) -> R.Tensor((36,), dtype="float32"):
+ gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((36,),
dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def repeat(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)),
"float32"),
+ T_repeat: T.Buffer((T.int64(36),), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ T_reshape = T.alloc_buffer((T.int64(18),))
+ for ax0 in range(T.int64(18)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(18), ax0)
+ T.reads(
+ rxplaceholder[
+ v_ax0 % T.int64(18) // T.int64(6),
+ v_ax0 % T.int64(6) // T.int64(3),
+ v_ax0 % T.int64(3),
+ ]
+ )
+ T.writes(T_reshape[v_ax0])
+ T_reshape[v_ax0] = rxplaceholder[
+ v_ax0 % T.int64(18) // T.int64(6),
+ v_ax0 % T.int64(6) // T.int64(3),
+ v_ax0 % T.int64(3),
+ ]
+ for ax0 in range(T.int64(36)):
+ with T.block("T_repeat"):
+ v_ax0 = T.axis.spatial(T.int64(36), ax0)
+ T.reads(T_reshape[v_ax0 // T.int64(2)])
+ T.writes(T_repeat[v_ax0])
+ T_repeat[v_ax0] = T_reshape[v_ax0 // T.int64(2)]
+ # fmt: on
+
+ mod = LegalizeOps()(Repeat)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_repeat_symbolic():
+ # fmt: off
+ @I.ir_module
+ class Repeat:
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), "float32")):
+ gv = R.repeat(x, 2, 0)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def repeat(var_rxplaceholder: T.handle, var_T_repeat: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c))
+ T_repeat = T.match_buffer(var_T_repeat, (T.int64(2) * a, b, c))
+ # with T.block("root"):
+ for ax0, ax1, ax2 in T.grid(a * T.int64(2), b, c):
+ with T.block("T_repeat"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[v_ax0 // T.int64(2), v_ax1, v_ax2])
+ T.writes(T_repeat[v_ax0, v_ax1, v_ax2])
+ T_repeat[v_ax0, v_ax1, v_ax2] = rxplaceholder[v_ax0 //
T.int64(2), v_ax1, v_ax2]
+
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) ->
R.Tensor(("2 * a", "b", "c"), dtype="float32"):
+ a = T.Var("a", "int64")
+ b = T.Var("b", "int64")
+ c = T.Var("c", "int64")
+ gv = R.call_tir(repeat, (x,), out_sinfo=R.Tensor((2 * a, b, c),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Repeat)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_tile():
+ # fmt: off
+ @I.ir_module
+ class Tile:
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), "float32")):
+ gv = R.tile(x, (2, 1, 2, 3))
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)),
"float32"), T_tile: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(9)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(9)):
+ with T.block("T_tile"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax1 % T.int64(3), v_ax2 %
T.int64(2), v_ax3 % T.int64(3)])
+ T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 %
T.int64(3), v_ax2 % T.int64(2), v_ax3 % T.int64(3)]
+
+ @R.function
+ def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((2, 3,
4, 9), dtype="float32"):
+ gv = R.call_tir(tile, (x,), out_sinfo=R.Tensor((2, 3, 4, 9),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Tile)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_tile_symbolic():
+ # fmt: off
+ @I.ir_module
+ class Tile:
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), "float32")):
+ gv = R.tile(x, (2, 1, 2, 3))
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def tile(var_rxplaceholder: T.handle, var_T_tile: T.handle):
+ T.func_attr({"tir.noalias": True})
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c))
+ T_tile = T.match_buffer(var_T_tile, (T.int64(2), a, b *
T.int64(2), c * T.int64(3)))
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), a, b * T.int64(2), c
* T.int64(3)):
+ with T.block("T_tile"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax1 % a, v_ax2 % b, v_ax3 % c])
+ T.writes(T_tile[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_tile[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax1 %
a, v_ax2 % b, v_ax3 % c]
+
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) ->
R.Tensor((2, "a", "b * 2", "c * 3"), dtype="float32"):
+ a = T.Var("a", "int64")
+ b = T.Var("b", "int64")
+ c = T.Var("c", "int64")
+ gv = R.call_tir(tile, (x,), out_sinfo=R.Tensor((2, a, b * 2, c *
3), dtype="float32"))
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(Tile)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 8fb398f15d..90e89944f7 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -18,7 +18,7 @@
import pytest
import tvm
from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
@@ -207,6 +207,188 @@ def test_conv2d_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_conv2d_transpose():
+ # fmt: off
+ @I.ir_module
+ class Conv2dTranspose:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((128,
16, 3, 3), "float32")):
+ gv = R.nn.conv2d_transpose(x, w, strides=(2, 3), padding=(1, 1),
dilation=(1, 1), output_padding=(1, 2), groups=8)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w:
R.Tensor((128, 16, 3, 3), dtype="float32")) -> R.Tensor((2, 128, 56, 84),
dtype="float32"):
+ gv = R.call_tir(conv2d_transpose, (x, w), out_sinfo=R.Tensor((2,
128, 56, 84), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2),
T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1:
T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"),
compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ data_dilate = T.alloc_buffer((T.int64(2), T.int64(128),
T.int64(55), T.int64(82)))
+ data_pad = T.alloc_buffer((T.int64(2), T.int64(128), T.int64(58),
T.int64(86)))
+ kernel_transform = T.alloc_buffer((T.int64(16), T.int64(128),
T.int64(3), T.int64(3)))
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128),
T.int64(55), T.int64(82)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3
// T.int64(3)])
+ T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+ data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2
% T.int64(2) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0),
rxplaceholder[v_i0, v_i1, v_i2 // T.int64(2), v_i3 // T.int64(3)], T.float32(0))
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128),
T.int64(58), T.int64(86)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 -
T.int64(1)])
+ T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+ data_pad[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(56) and T.int64(1) <= v_i3
and v_i3 < T.int64(83), data_dilate[v_i0, v_i1, v_i2 - T.int64(1), v_i3 -
T.int64(1)], T.float32(0))
+ for i, o, h, w in T.grid(T.int64(16), T.int64(128), T.int64(3),
T.int64(3)):
+ with T.block("kernel_transform"):
+ v_i, v_o, v_h, v_w = T.axis.remap("SSSS", [i, o, h, w])
+ T.reads(rxplaceholder_1[v_o, v_i, T.int64(2) - v_h,
T.int64(2) - v_w])
+ T.writes(kernel_transform[v_i, v_o, v_h, v_w])
+ kernel_transform[v_i, v_o, v_h, v_w] =
rxplaceholder_1[v_o, v_i, T.int64(2) - v_h, T.int64(2) - v_w]
+ for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(128),
T.int64(56), T.int64(84), T.int64(16), T.int64(3), T.int64(3)):
+ with T.block("compute"):
+ v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw =
T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
+ T.reads(data_pad[v_b, v_c // T.int64(16) * T.int64(16) +
v_dc, v_h + v_dh, v_w + v_dw], kernel_transform[v_c % T.int64(16), v_c //
T.int64(16) * T.int64(16) + v_dc, v_dh, v_dw])
+ T.writes(compute[v_b, v_c, v_h, v_w])
+ with T.init():
+ compute[v_b, v_c, v_h, v_w] = T.float32(0)
+ compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w]
+ data_pad[v_b, v_c // T.int64(16) * T.int64(16) + v_dc, v_h + v_dh, v_w +
v_dw] * kernel_transform[v_c % T.int64(16), v_c // T.int64(16) * T.int64(16) +
v_dc, v_dh, v_dw]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv2dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv2d_transpose_with_out_dtype():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv2dTranspose:
+ @R.function
+ def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((3, 4, 3,
3), "float32")):
+ gv = R.nn.conv2d_transpose(x, w, out_dtype="float16")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3,
4, 3, 3), dtype="float32")) -> R.Tensor((2, 4, 30, 30), dtype="float16"):
+ gv = R.call_tir(conv2d_transpose, (x, w), out_sinfo=R.Tensor((2,
4, 30, 30), dtype="float16"))
+ return gv
+
+ @T.prim_func
+ def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),
T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2),
T.int64(4), T.int64(30), T.int64(30)), "float16")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ data_dilate = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28),
T.int64(28)))
+ data_pad = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(32),
T.int64(32)))
+ kernel_transform = T.alloc_buffer((T.int64(4), T.int64(3),
T.int64(3), T.int64(3)))
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28),
T.int64(28)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_i3])
+ T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+ data_dilate[v_i0, v_i1, v_i2, v_i3] = rxplaceholder[v_i0,
v_i1, v_i2, v_i3]
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32),
T.int64(32)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 -
T.int64(2)])
+ T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+ data_pad[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(T.int64(2) <= v_i2 and v_i2 < T.int64(30) and T.int64(2) <= v_i3
and v_i3 < T.int64(30), data_dilate[v_i0, v_i1, v_i2 - T.int64(2), v_i3 -
T.int64(2)], T.float32(0))
+ for o, i, h, w in T.grid(T.int64(4), T.int64(3), T.int64(3),
T.int64(3)):
+ with T.block("kernel_transform"):
+ v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h, w])
+ T.reads(rxplaceholder_1[v_i, v_o, T.int64(2) - v_h,
T.int64(2) - v_w])
+ T.writes(kernel_transform[v_o, v_i, v_h, v_w])
+ kernel_transform[v_o, v_i, v_h, v_w] =
rxplaceholder_1[v_i, v_o, T.int64(2) - v_h, T.int64(2) - v_w]
+ for b, c, h, w, dc, dh, dw in T.grid(T.int64(2), T.int64(4),
T.int64(30), T.int64(30), T.int64(3), T.int64(3), T.int64(3)):
+ with T.block("compute"):
+ v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw =
T.axis.remap("SSSSRRR", [b, c, h, w, dc, dh, dw])
+ T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw],
kernel_transform[v_c, v_dc, v_dh, v_dw])
+ T.writes(compute[v_b, v_c, v_h, v_w])
+ with T.init():
+ compute[v_b, v_c, v_h, v_w] = T.float16(0)
+ compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w]
+ T.Cast("float16", data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw]) *
T.Cast("float16", kernel_transform[v_c, v_dc, v_dh, v_dw])
+ # fmt: on
+
+ mod = LegalizeOps()(Conv2dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_conv2d_transpose_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class Conv2dTranspose:
+ @R.function
+ def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel:
R.Tensor(("f", "c", "kh", "kw"), "float32")):
+ gv = R.nn.conv2d_transpose(x, kernel, strides=(3, 3))
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel:
R.Tensor(("f", "c", "kh", "kw"), dtype="float32")) -> R.Tensor(("n", "c", "h *
3 + kh - 3", "w * 3 + kw - 3"), dtype="float32"):
+ n = T.Var("n", "int64")
+ c = T.Var("c", "int64")
+ h = T.Var("h", "int64")
+ kh = T.Var("kh", "int64")
+ w = T.Var("w", "int64")
+ kw = T.Var("kw", "int64")
+ f = T.Var("f", "int64")
+ gv = R.call_tir(conv2d_transpose, (x, kernel),
out_sinfo=R.Tensor((n, c, h * 3 + kh - 3, w * 3 + kw - 3), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_compute: T.handle):
+ T.func_attr({"tir.noalias": True})
+ n = T.var("int64")
+ c = T.var("int64")
+ h = T.var("int64")
+ w = T.var("int64")
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, h, w))
+ f = T.var("int64")
+ kh = T.var("int64")
+ kw = T.var("int64")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (f, c, kh,
kw))
+ compute = T.match_buffer(var_compute, (n, c, h * T.int64(3) + kh -
T.int64(3), w * T.int64(3) + kw - T.int64(3)))
+ # with T.block("root"):
+ data_dilate = T.alloc_buffer((n, c, h * T.int64(3) - T.int64(2), w
* T.int64(3) - T.int64(2)))
+ data_pad = T.alloc_buffer((n, c, h * T.int64(3) + kh * T.int64(2)
- T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)))
+ kernel_transform = T.alloc_buffer((c, c, kh, kw))
+ for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) - T.int64(2), w
* T.int64(3) - T.int64(2)):
+ with T.block("data_dilate"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3
// T.int64(3)])
+ T.writes(data_dilate[v_i0, v_i1, v_i2, v_i3])
+ data_dilate[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(v_i2
% T.int64(3) == T.int64(0) and v_i3 % T.int64(3) == T.int64(0),
rxplaceholder[v_i0, v_i1, v_i2 // T.int64(3), v_i3 // T.int64(3)], T.float32(0))
+ for i0, i1, i2, i3 in T.grid(n, c, h * T.int64(3) + kh *
T.int64(2) - T.int64(4), w * T.int64(3) + kw * T.int64(2) - T.int64(4)):
+ with T.block("data_pad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh,
v_i3 + T.int64(1) - kw])
+ T.writes(data_pad[v_i0, v_i1, v_i2, v_i3])
+ data_pad[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(kh -
T.int64(1) <= v_i2 and v_i2 < h * T.int64(3) + kh - T.int64(3) and kw -
T.int64(1) <= v_i3 and v_i3 < w * T.int64(3) + kw - T.int64(3),
data_dilate[v_i0, v_i1, v_i2 + T.int64(1) - kh, v_i3 + T.int64(1) - kw],
T.float32(0))
+ for o, i, h_1, w_1 in T.grid(c, c, kh, kw):
+ with T.block("kernel_transform"):
+ v_o, v_i, v_h, v_w = T.axis.remap("SSSS", [o, i, h_1, w_1])
+ T.reads(rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1),
kw - v_w - T.int64(1)])
+ T.writes(kernel_transform[v_o, v_i, v_h, v_w])
+ kernel_transform[v_o, v_i, v_h, v_w] =
rxplaceholder_1[v_i, v_o, kh - v_h - T.int64(1), kw - v_w - T.int64(1)]
+ for b, c_1, h_1, w_1, dc, dh, dw in T.grid(n, c, h * T.int64(3) +
kh - T.int64(3), w * T.int64(3) + kw - T.int64(3), c, kh, kw):
+ with T.block("compute"):
+ v_b, v_c, v_h, v_w, v_dc, v_dh, v_dw =
T.axis.remap("SSSSRRR", [b, c_1, h_1, w_1, dc, dh, dw])
+ T.reads(data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw],
kernel_transform[v_c, v_dc, v_dh, v_dw])
+ T.writes(compute[v_b, v_c, v_h, v_w])
+ with T.init():
+ compute[v_b, v_c, v_h, v_w] = T.float32(0)
+ compute[v_b, v_c, v_h, v_w] = compute[v_b, v_c, v_h, v_w]
+ data_pad[v_b, v_dc, v_h + v_dh, v_w + v_dw] * kernel_transform[v_c, v_dc,
v_dh, v_dw]
+ # fmt: on
+
+ mod = LegalizeOps()(Conv2dTranspose)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_max_pool2d():
# fmt: off
@tvm.script.ir_module
@@ -345,6 +527,169 @@ def test_max_pool2d_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_avg_pool2d():
+ # fmt: off
+ @tvm.script.ir_module
+ class AvgPool2D:
+ @R.function
+ def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56,
56, 6), "float32"):
+ gv: R.Tensor((4, 56, 56, 6), "float32") = R.nn.avg_pool2d(x,
pool_size=[3, 3], strides=[2, 2], dilation=[1, 1], padding=[1, 1, 1, 1],
layout="NHWC")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112),
T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4),
T.int64(56), T.int64(56), T.int64(6)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((T.int64(4), T.int64(114), T.int64(114),
T.int64(6)))
+ pool_sum = T.alloc_buffer((T.int64(4), T.int64(56), T.int64(56),
T.int64(6)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(114),
T.int64(114), T.int64(6)):
+ with T.block("pad_temp"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2 -
T.int64(1), v_ax3])
+ T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
+ pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(T.int64(1) <= v_ax1 and v_ax1 < T.int64(113) and T.int64(1) <=
v_ax2 and v_ax2 < T.int64(113), rxplaceholder[v_ax0, v_ax1 - T.int64(1), v_ax2
- T.int64(1), v_ax3], T.float32(0))
+ for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4),
T.int64(56), T.int64(56), T.int64(6), T.int64(3), T.int64(3)):
+ with T.block("pool_sum"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
+ T.reads(pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2
* T.int64(2) + v_rv1, v_ax3])
+ T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1 * T.int64(2) + v_rv0, v_ax2 *
T.int64(2) + v_rv1, v_ax3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(56),
T.int64(56), T.int64(6)):
+ with T.block("pool_avg"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) -
v_ax1 * T.int64(2)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax2 *
T.int64(2)) + T.int64(2)))
+
+ @R.function
+ def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) ->
R.Tensor((4, 56, 56, 6), dtype="float32"):
+ gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 56, 56,
6), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(AvgPool2D)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_avg_pool2d_NCHW16c():
+ # fmt: off
+ @tvm.script.ir_module
+ class AvgPool2D:
+ @R.function
+ def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4,
4, 110, 110, 16), "float32"):
+ gv: R.Tensor((4, 4, 110, 110, 16), "float32") = R.nn.avg_pool2d(x,
pool_size=[3, 3], layout="NCHW16c")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4),
T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg:
T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pool_sum = T.alloc_buffer((T.int64(4), T.int64(4), T.int64(110),
T.int64(110), T.int64(16)))
+ for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4),
T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)):
+ with T.block("pool_sum"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4, v_rv0, v_rv1 =
T.axis.remap("SSSSSRR", [ax0, ax1, ax2, ax3, ax4, rv0, rv1])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + v_rv0, v_ax3 +
v_rv1, v_ax4])
+ T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ with T.init():
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
T.float32(0)
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + rxplaceholder[v_ax0, v_ax1, v_ax2
+ v_rv0, v_ax3 + v_rv1, v_ax4]
+ for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(4), T.int64(4),
T.int64(110), T.int64(110), T.int64(16)):
+ with T.block("pool_avg"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS",
[ax0, ax1, ax2, ax3, ax4])
+ T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32",
(T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1)) * (T.min(T.int64(2),
T.int64(111) - v_ax3) + T.int64(1)))
+
+ @R.function
+ def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) ->
R.Tensor((4, 4, 110, 110, 16), dtype="float32"):
+ gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 4, 110,
110, 16), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(AvgPool2D)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_avg_pool2d_ceil_mode():
+ # fmt: off
+ @tvm.script.ir_module
+ class AvgPool2D:
+ @R.function
+ def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6,
38, 38), "float32"):
+ gv: R.Tensor((4, 6, 38, 38), "float32") = R.nn.avg_pool2d(x,
pool_size=[3, 3], strides=[3, 3], dilation=[1, 1], padding=[1, 1, 1, 1],
ceil_mode=True)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6),
T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4),
T.int64(6), T.int64(38), T.int64(38)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(116),
T.int64(116)))
+ pool_sum = T.alloc_buffer((T.int64(4), T.int64(6), T.int64(38),
T.int64(38)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6),
T.int64(116), T.int64(116)):
+ with T.block("pad_temp"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1),
v_ax3 - T.int64(1)])
+ T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
+ pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(T.int64(1) <= v_ax2 and v_ax2 < T.int64(113) and T.int64(1) <=
v_ax3 and v_ax3 < T.int64(113), rxplaceholder[v_ax0, v_ax1, v_ax2 - T.int64(1),
v_ax3 - T.int64(1)], T.float32(0))
+ for ax0, ax1, ax2, ax3, rv0, rv1 in T.grid(T.int64(4), T.int64(6),
T.int64(38), T.int64(38), T.int64(3), T.int64(3)):
+ with T.block("pool_sum"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
+ T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0,
v_ax3 * T.int64(3) + v_rv1])
+ T.writes(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+ pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] + pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(3) + v_rv0, v_ax3
* T.int64(3) + v_rv1]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(6),
T.int64(38), T.int64(38)):
+ with T.block("pool_avg"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) -
v_ax2 * T.int64(3)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax3 *
T.int64(3)) + T.int64(2)))
+
+ @R.function
+ def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) ->
R.Tensor((4, 6, 38, 38), dtype="float32"):
+ gv = R.call_tir(avg_pool2d, (x,), out_sinfo=R.Tensor((4, 6, 38,
38), dtype="float32"))
+ return gv
+
+ # fmt: on
+
+ mod = LegalizeOps()(AvgPool2D)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
[email protected]("TOPI pooling casts every shape value to i32.")
+def test_avg_pool2d_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class AvgPool2D:
+ @R.function
+ def main(dumb_param: R.Tensor(("kh", "kw")), x: R.Tensor(("n", "c",
"h", "w"), "float32")) -> R.Tensor(("n", "c", "h - kh + 1", "w - kw + 1"),
"float32"):
+ n = T.int64()
+ c = T.int64()
+ h = T.int64()
+ w = T.int64()
+ kh = T.int64()
+ kw = T.int64()
+ gv: R.Tensor((n, c, h - kh + 1, w - kw + 1), "float32") =
R.nn.avg_pool2d(x, pool_size=[kh, kw])
+ return gv
+
+ # fmt: on
+
+ mod = LegalizeOps()(AvgPool2D)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_adaptive_avg_pool2d():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index c1d0c90d34..10b786ee4a 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -343,5 +343,50 @@ def test_collapse_sum_to():
_check(foo, bb.get()["foo"])
+def test_repeat():
+ @R.function
+ def foo(x: R.Tensor((2, 3, 4), "float32")):
+ gv = R.repeat(x, 3, 1)
+ return gv
+
+ x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.repeat(x, 3, 1))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_repeat_no_axis():
+ @R.function
+ def foo(x: R.Tensor((2, 3, 4), "float32")):
+ gv = R.repeat(x, 3)
+ return gv
+
+ x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.repeat(x, 3))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_tile():
+ @R.function
+ def foo(x: R.Tensor((2, 3, 4), "float32")):
+ gv = R.tile(x, (2, 3))
+ return gv
+
+ x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.tile(x, (2, 3)))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py
b/tests/python/relax/test_tvmscript_parser_op_nn.py
index c2bfa5b7a9..cfb454a578 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -53,6 +53,26 @@ def test_conv2d():
_check(foo, bb.get()["foo"])
+def test_conv2d_transpose():
+ @R.function
+ def foo(
+ x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((3, 16, 5, 5),
"float32")
+ ) -> R.Tensor((2, 16, 232, 232), "float16"):
+ gv: R.Tensor((2, 16, 232, 232), "float16") = R.nn.conv2d_transpose(
+ x, w, out_dtype="float16"
+ )
+ return gv
+
+ x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32"))
+ w = relax.Var("w", R.Tensor([3, 16, 5, 5], "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, w]):
+ gv = bb.emit(relax.op.nn.conv2d_transpose(x, w, out_dtype="float16"))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_max_pool2d():
@R.function
def foo(
@@ -70,6 +90,23 @@ def test_max_pool2d():
_check(foo, bb.get()["foo"])
+def test_avg_pool2d():
+ @R.function
+ def foo(
+ x: R.Tensor((1, 1, 32, 32), dtype="float32")
+ ) -> R.Tensor((1, 1, 30, 30), dtype="float32"):
+ gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.avg_pool2d(x,
pool_size=(3,))
+ return gv
+
+ x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x]):
+ gv = bb.emit(relax.op.nn.avg_pool2d(x, pool_size=(3,)))
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_adaptive_avg_pool2d():
@R.function
def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7),
"float32"):