This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e0501b527d [Unity][Frontend][NN] Collection of new NN operators 
(#15642)
e0501b527d is described below

commit e0501b527daf53f93677a502dca6ed91517feadb
Author: Josh Fromm <[email protected]>
AuthorDate: Wed Aug 30 15:19:22 2023 -0700

    [Unity][Frontend][NN] Collection of new NN operators (#15642)
    
    This PR adds a bunch of new NN operators that conform to torch semantics. I
    also updated a couple of existing operators like LayerNorm to be more in line
    with the torch equivalent. Specifically this PR adds
    * unsqueeze
    * concat
    * subtract
    * Identity
    * chunk
    * pad
    * interpolate
    I also fixed an issue in RewriteDataflowReshape when operating on scalars.
---
 python/tvm/relax/frontend/nn/modules.py            |  40 ++-
 python/tvm/relax/frontend/nn/op.py                 | 285 +++++++++++++++++++--
 python/tvm/relax/op/image/image.py                 |   2 +-
 src/relax/transform/rewrite_dataflow_reshape.cc    |   9 +-
 tests/python/relax/test_frontend_nn_modules.py     |  20 +-
 tests/python/relax/test_frontend_nn_op.py          | 119 +++++++--
 tests/python/relax/test_op_image.py                |   2 -
 .../test_transform_rewrite_dataflow_reshape.py     |  54 ++++
 8 files changed, 479 insertions(+), 52 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/modules.py 
b/python/tvm/relax/frontend/nn/modules.py
index 6df0957398..3eb3c569dc 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -78,6 +78,24 @@ class SiLU(Module):
         return op.silu(x)
 
 
+class Identity(Module):
+    """Module that does nothing, sometimes useful for naming purposes."""
+
+    def forward(self, x: Tensor):
+        """Forward method for identity.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input tensor.
+        Returns
+        -------
+        Result : Tensor
+            The unchanged input tensor.
+        """
+        return x
+
+
 class Linear(Module):
     """
     Module for linear layer.
@@ -250,15 +268,20 @@ class LayerNorm(Module):
     def __init__(
         self,
         normalized_shape: int,
-        axes: Union[int, List[int]],
         eps: Optional[float] = 1e-5,
+        elementwise_affine: bool = True,
         dtype: Optional[str] = None,
     ) -> None:
         super().__init__()
+        self.normalized_shape = normalized_shape
         self.eps = eps
-        self.axes = axes
-        self.weight = Parameter((normalized_shape,), dtype=dtype)
-        self.bias = Parameter((normalized_shape,), dtype=dtype)
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = Parameter((normalized_shape,), dtype=dtype)
+            self.bias = Parameter((normalized_shape,), dtype=dtype)
+        else:
+            self.weight = None
+            self.bias = None
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -274,8 +297,13 @@ class LayerNorm(Module):
         ret : Tensor
             The output tensor for the layer normalization layer.
         """
-        out = op.layer_norm(x, weight=self.weight, bias=self.bias, 
axes=self.axes, epsilon=self.eps)
-        return out
+        return op.layer_norm(
+            x,
+            normalized_shape=self.normalized_shape,
+            weight=self.weight,
+            bias=self.bias,
+            eps=self.eps,
+        )
 
 
 class RMSNorm(Module):
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index 4ef02797c2..5473fcb499 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -14,18 +14,19 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=too-many-lines,invalid-name,protected-access
+# pylint: 
disable=too-many-lines,invalid-name,protected-access,redefined-outer-name
 """nn.Tensor operators."""
 import math
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
+import numpy as np
 from tvm import tir as _tir
 
 from ... import expr as rx
 from ... import op as _op
 from ...block_builder import BlockBuilder
 from ...struct_info import TensorStructInfo, TupleStructInfo
-from .core import Tensor
+from .core import Tensor, get_default_dtype
 from .spec import SpecBuilder
 
 IntExpr = Union[int, _tir.PrimExpr]
@@ -57,11 +58,53 @@ def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, 
Tuple[Tensor]]:
                 rx.TupleGetItem(expr, i),
                 name=f"{name}.{i}",
             )
-            for i in range(expr.struct_info_.fields)
+            for i in range(len(expr.struct_info_.fields))
         )
     raise TypeError(f"Unsupported return type: {expr.struct_info_}")
 
 
+def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor:
+    """Add a new axis to a tensor
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to expand.
+    dim : int
+        Dimension to expand.
+    name : str
+        Name hint for this operator.
+
+    Returns
+    -------
+    result : Tensor
+        Expanded result.
+    """
+    return _wrap_nested(_op.expand_dims(x._expr, dim), name)
+
+
+def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor:
+    """Concatenate a list of tensors along an axis.
+
+    Parameters
+    ----------
+    x : List[Tensor]
+        List of tensors to concatenate.
+    dim : int
+        Dimension to concatenate upon.
+    name : str
+        Name hint for this operator.
+
+    Returns
+    -------
+    result : Tensor
+        Expanded result.
+    """
+    # Convert tensors to expressions.
+    x = [t._expr for t in x]
+    return _wrap_nested(_op.concat(x, dim), name)
+
+
 def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
     """Addition with numpy-style broadcasting.
 
@@ -90,6 +133,34 @@ def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
     return _wrap_nested(_op.add(a._expr, b._expr), name)
 
 
+def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor:
+    """Subtraction with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    a : Tensor
+        The first input tensor.
+
+    b : Tensor
+        The second input tensor.
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Examples
+    --------
+    .. code:: python
+
+        c = subtract(a, b)
+    """
+    return _wrap_nested(_op.subtract(a._expr, b._expr), name)
+
+
 def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor:
     """Multiplication with numpy-style broadcasting.
 
@@ -146,6 +217,28 @@ def divide(a: Tensor, b: Tensor, name: str = "divide") -> 
Tensor:
     return _wrap_nested(_op.divide(a._expr, b._expr), name)
 
 
+def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tensor:
+    """Split a tensor along dim into the specified number of chunks.
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to be split.
+    chunks : int
+        Number of pieces to slice x into.
+    dim : int
+        Which dimension to split x.
+    name : str
+        Name hint for this operation.
+
+    Returns
+    -------
+    result : Tuple[Tensor]
+        A tuple with chunks elements containing slices of x.
+    """
+    return _wrap_nested(_op.split(x._expr, chunks, dim), name)
+
+
 def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = 
"matmul") -> Tensor:
     """General matrix multiplication of two tensors, with broadcasting on 
batched dimensions.
 
@@ -569,6 +662,9 @@ def astype(x: Tensor, dtype: str, name: str = "astype") -> 
Tensor:
     result : Tensor
         The casted result.
     """
+    # If trying to cast to same dtype as x, skip casting.
+    if x.dtype == dtype:
+        return x
     return _wrap_nested(_op.astype(x._expr, dtype), name)
 
 
@@ -598,7 +694,7 @@ def silu(x: Tensor, name: str = "silu") -> Tensor:
     return _wrap_nested(_op.nn.silu(x._expr), name)
 
 
-def gelu(x: Tensor, name: str = "gelu") -> Tensor:
+def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> 
Tensor:
     r"""Applies the Gaussian Error Linear Units function
 
     .. math::
@@ -611,7 +707,10 @@ def gelu(x: Tensor, name: str = "gelu") -> Tensor:
     x : Tensor
         The input data
 
-    naem : str
+    approximate : Optional[str]
+        If set to tanh, use an approximation when calculating CDF.
+
+    name : str
         Name hint.
 
     Returns
@@ -623,7 +722,20 @@ def gelu(x: Tensor, name: str = "gelu") -> Tensor:
     ----
     The input tensor is required to have float dtype
     """
-    return _wrap_nested(_op.nn.gelu(x._expr), name)
+    dtype = x._expr.struct_info.dtype
+    if approximate == "tanh":
+        tanh_const = rx.const(1 + np.tanh(np.sqrt(2 / np.pi)), dtype=dtype)
+        gelu_out = (
+            rx.const(0.5, dtype)
+            * x._expr
+            * (
+                tanh_const
+                * (x._expr + (rx.const(0.044715, dtype) * _op.power(x._expr, 
rx.const(3, "int32"))))
+            )
+        )
+    else:
+        gelu_out = _op.nn.gelu(x._expr)
+    return _wrap_nested(gelu_out, name)
 
 
 def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
@@ -658,10 +770,10 @@ def softmax(x: Tensor, axis: int = -1, name: str = 
"softmax") -> Tensor:
 
 def layer_norm(
     x: Tensor,
-    weight: Tensor,
-    bias: Tensor,
-    axes: Union[int, List[int]],
-    epsilon: float = 1e-5,
+    normalized_shape: Union[int, List[int]],
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-5,
     name: str = "layer_norm",
 ) -> Tensor:
     r"""
@@ -685,19 +797,21 @@ def layer_norm(
 
     Parameters
     ----------
-    data : Tensor
+    x : Tensor
         Input to which layer_norm will be applied.
 
-    gamma : Tensor
+    normalized_shape: Union[int, List[int]]
+        The shape of axes to normalize. If a single integer
+        is used, it is treated as a singleton list and this
+        module will normalize over the last dimension.
+
+    weight: Tensor
         The gamma scale factor.
 
-    beta : Tensor
+    bias: Tensor
         The beta offset factor.
 
-    axes : Union[int, List[int]]
-        The axes that along which the normalization is applied.
-
-    epsilon : float
+    eps: float
         Small float added to variance to avoid dividing by zero.
 
     name : str
@@ -708,7 +822,31 @@ def layer_norm(
     result : Tensor
         The computed result.
     """
-    return _wrap_nested(_op.nn.layer_norm(x._expr, weight._expr, bias._expr, 
axes, epsilon), name)
+    if isinstance(normalized_shape, int):
+        normalized_shape = [normalized_shape]
+    dim_num = len(normalized_shape)
+    axes = list(range(-dim_num, 0))
+    dtype = x._expr.struct_info.dtype
+
+    if weight is not None:
+        weight = weight._expr
+    else:
+        weight = rx.const(np.ones(normalized_shape), dtype=dtype)
+    if bias is not None:
+        bias = bias._expr
+    else:
+        bias = rx.const(np.zeros(normalized_shape), dtype=dtype)
+
+    return _wrap_nested(
+        _op.nn.layer_norm(
+            x._expr,
+            gamma=weight,
+            beta=bias,
+            axes=axes,
+            epsilon=eps,
+        ),
+        name=name,
+    )
 
 
 def rms_norm(
@@ -906,6 +1044,39 @@ def zeros(
     return _wrap_nested(_op.zeros(shape, dtype), name)
 
 
+def pad(
+    x: Tensor,
+    pad: List[int],
+    mode: str = "constant",
+    value: int = 0,
+    name: str = "pad",
+) -> Tensor:
+    """
+    Apply spatial padding to the input tensor.
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to be padded.
+    pad : List[int]
+        List in the format of [before_0, after_0, before_1, after_1, ...]
+        indicating how much to pad each axis of x.
+    mod : str
+        Padding mode to use, constant implies padded elements will use
+        value argument.
+    value : int
+        What to pad with in constant mode.
+    name : str
+        Name hint for this operator.
+
+    Returns
+    -------
+    result : Tensor
+        Padded output tensor.
+    """
+    return _wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, 
pad_mode=mode), name)
+
+
 def get_timestep_embedding(
     x: Tensor,
     embedding_dim: int,
@@ -940,19 +1111,20 @@ def get_timestep_embedding(
     result : Tensor
         [N x dim] Tensor of positional embeddings.
     """
-    timesteps = _op.astype(x._expr, "float32")
+    dtype = get_default_dtype()
+    timesteps = _op.astype(x._expr, dtype)
 
     half_dim = embedding_dim // 2
-    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
-        start=0, end=half_dim, dtype="float32"
+    exponent = rx.const(-math.log(max_period), dtype) * _op.arange(
+        start=0, end=half_dim, dtype=dtype
     )
-    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, 
"float32"))
+    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, dtype))
 
     emb = _op.exp(exponent)
     emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
     # Scale embeddings
     if scale != 1:
-        emb = rx.const(scale, "float32") * emb
+        emb = rx.const(scale, dtype) * emb
 
     # Concat sine and cosine embeddings.
     if flip_sin_to_cos:
@@ -1005,6 +1177,73 @@ def scaled_dot_product_attention(
     return _wrap_nested(attn, name)
 
 
+def interpolate(
+    x: Tensor,
+    size: Optional[Union[int, Tuple[int]]] = None,
+    scale_factor: Optional[Union[float, Tuple[float]]] = None,
+    mode: str = "nearest",
+    align_corners: Optional[bool] = None,
+    recompute_scale_factor: Optional[bool] = None,
+    antialias: Optional[bool] = None,
+    name: str = "interpolate",
+):
+    """Resize a tensor using the specified mode.
+
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to be resized.
+    size : Optional[Union[int, Tuple[int]]]
+        Requested output size, only one of size and scale_factor may
+        be specified.
+    scale_factor : Optional[Union[float, Tuple[float]]]
+        Multiplier for spatial size.
+    mode : str
+        Algorithm used for sampling.
+    align_corners : Optional[bool]
+        How to map pixels before and after sampling.
+    recompute_scale_factor : Optional[bool]
+        Recompute the scale_factor for use in interpolation.
+    antialias : Optional[bool]
+        Apply antialiasing to output.
+    name : str
+        Name hint for this operation.
+
+    Returns
+    -------
+    result : Tensor
+        Output tensor with requested shape.
+    """
+    assert recompute_scale_factor is None, "recompute_scale_factor is not 
supported."
+    assert antialias is None, "antialias is not supported."
+
+    if size is None:
+        shape = x.shape
+        if isinstance(scale_factor, (list, tuple)):
+            size = tuple(int(shape[i] * scale_factor[i]) for i in range(2, 
len(shape)))
+        else:
+            size = tuple(int(shape[i] * scale_factor) for i in range(2, 
len(shape)))
+
+    if mode.startswith("nearest"):
+        mode = "nearest_neighbor"
+    elif mode[0:2] == "bi":
+        mode = mode[2:]
+
+    if mode == "nearest_neighbor":
+        coord_trans = "asymmetric"
+    elif align_corners:
+        coord_trans = "align_corners"
+    else:
+        coord_trans = "half_pixel"
+
+    return _wrap_nested(
+        _op.image.resize2d(
+            x._expr, size, layout="NCHW", method=mode, 
coordinate_transformation_mode=coord_trans
+        ),
+        name,
+    )
+
+
 def tensor_expr_op(
     tensor_expr_func: Callable,
     name_hint: str,
diff --git a/python/tvm/relax/op/image/image.py 
b/python/tvm/relax/op/image/image.py
index 562de5021d..e314e9b49a 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -107,7 +107,7 @@ def resize2d(
 
     if isinstance(size, (int, PrimExpr)):
         size = (size, size)
-    if isinstance(size, tuple):
+    if isinstance(size, (tuple, list)):
         if len(size) == 1:
             size = ShapeExpr([size[0], size[0]])
         else:
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc 
b/src/relax/transform/rewrite_dataflow_reshape.cc
index 2053a56bb7..8345f3e0b7 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -127,8 +127,13 @@ class DataflowReshapeRewriter : public ExprMutator {
       return false;
     }
     auto product = [](Array<PrimExpr> args) -> PrimExpr {
-      ICHECK(!args.empty());
-      PrimExpr p = args[0];
+      PrimExpr p;
+      if (args.empty()) {
+        // Scalar tensors may be empty indicating a single element.
+        p = 1;
+      } else {
+        p = args[0];
+      }
       for (int i = 1, e = args.size(); i < e; ++i) p *= args[i];
       return p;
     };
diff --git a/tests/python/relax/test_frontend_nn_modules.py 
b/tests/python/relax/test_frontend_nn_modules.py
index dba4178f65..86f018383f 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -43,6 +43,22 @@ def test_silu():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_identity():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) 
= x, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Identity()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), 
"float32")}})
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_linear():
     @R.function
     def forward(
@@ -93,13 +109,13 @@ def test_layer_norm():
     @R.function
     def forward(x: R.Tensor((2, 4, 8), dtype="float32"), weight: 
R.Tensor((8,), dtype="float32"), bias: R.Tensor((8,), dtype="float32"), _io: 
R.Object) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
         with R.dataflow():
-            layer_norm: R.Tensor((2, 4, 8), dtype="float32") = 
R.nn.layer_norm(x, weight, bias, axes=[2], epsilon=1.0000000000000001e-05, 
center=True, scale=True)
+            layer_norm: R.Tensor((2, 4, 8), dtype="float32") = 
R.nn.layer_norm(x, weight, bias, axes=[-1], epsilon=1.0000000000000001e-05, 
center=True, scale=True)
             gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), 
R.Tuple(R.Object)) = layer_norm, (_io,)
             R.output(gv1)
         return gv1
     # fmt: on
 
-    mod = modules.LayerNorm(8, [2])
+    mod = modules.LayerNorm(8)
     tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8), 
"float32")}})
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index 6b59d0419a..255f8ab3ba 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -37,11 +37,12 @@ def test_binary():
             z3 = op.matmul(x, y)
             z4 = op.maximum(x, y)
             z5 = op.minimum(x, y)
-            return (z0, z1, z2, z3, z4, z5)
+            z6 = op.subtract(x, y)
+            return (z0, z1, z2, z3, z4, z5, z6)
 
     # fmt: off
     @R.function
-    def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), 
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)):
+    def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), 
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), 
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), 
dtype="float32")), R.Tuple(R.Object)):
         with R.dataflow():
             add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
             mul: R.Tensor((10, 10), dtype="float32") = R.multiply(x, y)
@@ -49,7 +50,8 @@ def test_binary():
             matmul: R.Tensor((1, 1), dtype="float32") = R.matmul(x, y, 
out_dtype="void")
             maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
             minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
-            gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)) = (add, mul, divide, 
matmul, maximum, minimum), (_io,)
+            subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y)
+            gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"), 
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")), 
R.Tuple(R.Object)) = (add, mul, divide, matmul, maximum, minimum, subtract), 
(_io,)
             R.output(gv1)
         return gv1
     # fmt: on
@@ -58,7 +60,6 @@ def test_binary():
     irmodule, _ = m.export_tvm(
         spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y": 
spec.Tensor([10, 1], "float32")}}
     )
-
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
 
@@ -70,24 +71,28 @@ def test_manipulate():
             z2 = op.reshape(x, [1, 10])
             z3 = op.repeat(x, repeats=2, axis=1)
             z4 = op.squeeze(x, 0)
-            return (z0, z1, z2, z3, z4)
+            z5 = op.unsqueeze(x, 0)
+            z6 = op.concat([x, x], dim=0)
+            return (z0, z1, z2, z3, z4, z5, z6)
 
     # fmt: off
     @R.function
-    def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) -> 
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), 
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), 
dtype="float32"), R.Tensor((5, 2), dtype="float32")), R.Tuple(R.Object)):
+    def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) -> 
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), 
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), 
dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2), 
dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)):
         with R.dataflow():
             broadcast_to: R.Tensor((2, 5, 2), dtype="float32") = 
R.broadcast_to(x, R.shape([2, 5, 2]))
             permute_dims: R.Tensor((2, 5, 1), dtype="float32") = 
R.permute_dims(x, axes=[2, 1, 0])
             reshape: R.Tensor((1, 10), dtype="float32") = R.reshape(x, 
R.shape([1, 10]))
             repeat: R.Tensor((1, 10, 2), dtype="float32") = R.repeat(x, 
repeats=2, axis=1)
             squeeze: R.Tensor((5, 2), dtype="float32") = R.squeeze(x, axis=[0])
-            gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), 
R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), 
R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32")), 
R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze), 
(_io,)
+            unsqueeze: R.Tensor((1, 1, 5, 2), dtype="float32") = 
R.expand_dims(x, axis=0)
+            concat: R.Tensor((2, 5, 2), dtype="float32") = R.concat([x, x], 
axis=0)
+            gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), 
R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), 
R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"), 
R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), 
R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze, 
unsqueeze, concat), (_io,)
             R.output(gv1)
         return gv1
     # fmt: on
 
     m = Model()
-    irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2], 
"float32")}})
+    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2], 
"float32")}})
 
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
@@ -141,8 +146,10 @@ def test_datatype():
 def test_image():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
-            conv2d_out = op.conv2d(x, weight, bias)
-            return conv2d_out
+            padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
+            conv2d = op.conv2d(padded, weight, bias)
+            interpolate = op.interpolate(x, size=[40, 40])
+            return (conv2d, interpolate)
 
     @R.function
     def test(
@@ -150,19 +157,53 @@ def test_image():
         weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
         bias: R.Tensor((32,), dtype="float32"),
         _io: R.Object,
-    ) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"), 
R.Tuple(R.Object)):
+    ) -> R.Tuple(
+        R.Tuple(
+            R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tensor((1, 3, 40, 
40), dtype="float32")
+        ),
+        R.Tuple(R.Object),
+    ):
         with R.dataflow():
-            lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x, 
weight)
+            lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0, 
0, 0, 0, 1, 1, 1, 1))
+            lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d(
+                lv0,
+                weight,
+                strides=[1, 1],
+                padding=[0, 0, 0, 0],
+                dilation=[1, 1],
+                groups=1,
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+                out_layout="NCHW",
+                out_dtype="void",
+            )
             lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, 
R.shape([1, 32, 1, 1]))
-            conv2d: R.Tensor((1, 32, 30, 30), dtype="float32") = R.add(lv1, 
lv2)
-            gv1: R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"), 
R.Tuple(R.Object)) = conv2d, (
-                _io,
+            conv2d: R.Tensor((1, 32, 32, 32), dtype="float32") = R.add(lv1, 
lv2)
+            interpolate: R.Tensor((1, 3, 40, 40), dtype="float32") = 
R.image.resize2d(
+                x,
+                R.shape([40, 40]),
+                roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
+                layout="NCHW",
+                method="nearest_neighbor",
+                coordinate_transformation_mode="asymmetric",
+                rounding_method="round",
+                cubic_alpha=-0.5,
+                cubic_exclude=0,
+                extrapolation_value=0,
+                out_dtype="void",
             )
+            gv1: R.Tuple(
+                R.Tuple(
+                    R.Tensor((1, 32, 32, 32), dtype="float32"),
+                    R.Tensor((1, 3, 40, 40), dtype="float32"),
+                ),
+                R.Tuple(R.Object),
+            ) = (conv2d, interpolate), (_io,)
             R.output(gv1)
         return gv1
 
     m = Model()
-    irmodule, params = m.export_tvm(
+    irmodule, _ = m.export_tvm(
         spec={
             "test": {
                 "x": spec.Tensor([1, 3, 32, 32], "float32"),
@@ -174,6 +215,52 @@ def test_image():
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
 
+def test_chunk():
+    class Model(Module):
+        def test(self, x: Tensor):
+            chunk = op.chunk(x, chunks=4)
+            return chunk
+
+    @R.function
+    def test(
+        x: R.Tensor((8,), dtype="float32"), _io: R.Object
+    ) -> R.Tuple(
+        R.Tuple(
+            R.Tensor((2,), dtype="float32"),
+            R.Tensor((2,), dtype="float32"),
+            R.Tensor((2,), dtype="float32"),
+            R.Tensor((2,), dtype="float32"),
+        ),
+        R.Tuple(R.Object),
+    ):
+        with R.dataflow():
+            chunk: R.Tuple(
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+                R.Tensor((2,), dtype="float32"),
+            ) = R.split(x, indices_or_sections=4, axis=0)
+            chunk_0: R.Tensor((2,), dtype="float32") = chunk[0]
+            chunk_1: R.Tensor((2,), dtype="float32") = chunk[1]
+            chunk_2: R.Tensor((2,), dtype="float32") = chunk[2]
+            chunk_3: R.Tensor((2,), dtype="float32") = chunk[3]
+            gv1: R.Tuple(
+                R.Tuple(
+                    R.Tensor((2,), dtype="float32"),
+                    R.Tensor((2,), dtype="float32"),
+                    R.Tensor((2,), dtype="float32"),
+                    R.Tensor((2,), dtype="float32"),
+                ),
+                R.Tuple(R.Object),
+            ) = (chunk_0, chunk_1, chunk_2, chunk_3), (_io,)
+            R.output(gv1)
+        return gv1
+
+    m = Model()
+    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([8], 
"float32")}})
+    tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
 def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
diff --git a/tests/python/relax/test_op_image.py 
b/tests/python/relax/test_op_image.py
index 251a30139b..3b2db54848 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -244,8 +244,6 @@ def test_resize2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.image.resize2d(x1, size=32))
     with pytest.raises(TVMError):
         bb.normalize(relax.op.image.resize2d(x2, s0))
-    with pytest.raises(TVMError):
-        relax.op.image.resize2d(x2, [30, 30])
 
 
 if __name__ == "__main__":
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py 
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index b02663e5eb..01065bea21 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -446,5 +446,59 @@ def test_reshape_detect_nop():
     tvm.ir.assert_structural_equal(rewritten, Module)
 
 
+def test_reshape_scalar():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), 
dtype="float32"):
+            with R.dataflow():
+                lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, [1])
+                lv2: R.Tensor((1,), dtype="float32") = R.add(lv1, lv1)
+                R.output(lv2)
+            return lv2
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def add(
+            A: T.Buffer((T.int64(1),), "float32"),
+            B: T.Buffer((T.int64(1),), "float32"),
+            T_add: T.Buffer((T.int64(1),), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0 in range(T.int64(1)):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                    T.reads(A[v_ax0], B[v_ax0])
+                    T.writes(T_add[v_ax0])
+                    T_add[v_ax0] = A[v_ax0] + B[v_ax0]
+
+        @T.prim_func(private=True)
+        def reshape(A: T.Buffer((), "float32"), T_reshape: 
T.Buffer((T.int64(1),), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for ax0 in range(T.int64(1)):
+                with T.block("T_reshape"):
+                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                    T.reads(A[()])
+                    T.writes(T_reshape[v_ax0])
+                    T_reshape[v_ax0] = A[()]
+
+        @R.function
+        def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,), 
dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, 
R.shape([1]))
+                lv2 = R.call_tir(cls.add, (lv1, lv1), out_sinfo=R.Tensor((1,), 
dtype="float32"))
+                R.output(lv2)
+            return lv2
+
+    mod = Module
+    mod = relax.transform.LegalizeOps()(mod)
+    rewritten = relax.transform.RewriteDataflowReshape()(mod)
+    tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to