This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e0501b527d [Unity][Frontend][NN] Collection of new NN operators
(#15642)
e0501b527d is described below
commit e0501b527daf53f93677a502dca6ed91517feadb
Author: Josh Fromm <[email protected]>
AuthorDate: Wed Aug 30 15:19:22 2023 -0700
[Unity][Frontend][NN] Collection of new NN operators (#15642)
This PR adds a bunch of new NN operators that conform to torch semantics. I
also updated a couple of existing operators like LayerNorm to be more in line
with the torch equivalent. Specifically this PR adds
* unsqueeze
* concat
* subtract
* Identity
* chunk
* pad
I also fixed an issue in ReshapeDataflowRewrite when operating on scalars.
---
python/tvm/relax/frontend/nn/modules.py | 40 ++-
python/tvm/relax/frontend/nn/op.py | 285 +++++++++++++++++++--
python/tvm/relax/op/image/image.py | 2 +-
src/relax/transform/rewrite_dataflow_reshape.cc | 9 +-
tests/python/relax/test_frontend_nn_modules.py | 20 +-
tests/python/relax/test_frontend_nn_op.py | 119 +++++++--
tests/python/relax/test_op_image.py | 2 -
.../test_transform_rewrite_dataflow_reshape.py | 54 ++++
8 files changed, 479 insertions(+), 52 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index 6df0957398..3eb3c569dc 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -78,6 +78,24 @@ class SiLU(Module):
return op.silu(x)
+class Identity(Module):
+ """Module that does nothing, sometimes useful for naming purposes."""
+
+ def forward(self, x: Tensor):
+ """Forward method for identity.
+
+ Parameters
+ ----------
+ x : Tensor
+ The input tensor.
+ Returns
+ -------
+ Result : Tensor
+ The unchanged input tensor.
+ """
+ return x
+
+
class Linear(Module):
"""
Module for linear layer.
@@ -250,15 +268,20 @@ class LayerNorm(Module):
def __init__(
self,
normalized_shape: int,
- axes: Union[int, List[int]],
eps: Optional[float] = 1e-5,
+ elementwise_affine: bool = True,
dtype: Optional[str] = None,
) -> None:
super().__init__()
+ self.normalized_shape = normalized_shape
self.eps = eps
- self.axes = axes
- self.weight = Parameter((normalized_shape,), dtype=dtype)
- self.bias = Parameter((normalized_shape,), dtype=dtype)
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = Parameter((normalized_shape,), dtype=dtype)
+ self.bias = Parameter((normalized_shape,), dtype=dtype)
+ else:
+ self.weight = None
+ self.bias = None
def forward(self, x: Tensor) -> Tensor:
"""
@@ -274,8 +297,13 @@ class LayerNorm(Module):
ret : Tensor
The output tensor for the layer normalization layer.
"""
- out = op.layer_norm(x, weight=self.weight, bias=self.bias,
axes=self.axes, epsilon=self.eps)
- return out
+ return op.layer_norm(
+ x,
+ normalized_shape=self.normalized_shape,
+ weight=self.weight,
+ bias=self.bias,
+ eps=self.eps,
+ )
class RMSNorm(Module):
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
index 4ef02797c2..5473fcb499 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -14,18 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=too-many-lines,invalid-name,protected-access
+# pylint:
disable=too-many-lines,invalid-name,protected-access,redefined-outer-name
"""nn.Tensor operators."""
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+import numpy as np
from tvm import tir as _tir
from ... import expr as rx
from ... import op as _op
from ...block_builder import BlockBuilder
from ...struct_info import TensorStructInfo, TupleStructInfo
-from .core import Tensor
+from .core import Tensor, get_default_dtype
from .spec import SpecBuilder
IntExpr = Union[int, _tir.PrimExpr]
@@ -57,11 +58,53 @@ def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor,
Tuple[Tensor]]:
rx.TupleGetItem(expr, i),
name=f"{name}.{i}",
)
- for i in range(expr.struct_info_.fields)
+ for i in range(len(expr.struct_info_.fields))
)
raise TypeError(f"Unsupported return type: {expr.struct_info_}")
+def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor:
+ """Add a new axis to a tensor
+
+ Parameters
+ ----------
+ x : Tensor
+ Input tensor to expand.
+ dim : int
+ Dimension to expand.
+ name : str
+ Name hint for this operator.
+
+ Returns
+ -------
+ result : Tensor
+ Expanded result.
+ """
+ return _wrap_nested(_op.expand_dims(x._expr, dim), name)
+
+
+def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor:
+ """Concatenate a list of tensors along an axis.
+
+ Parameters
+ ----------
+ x : List[Tensor]
+ List of tensors to concatenate.
+ dim : int
+ Dimension to concatenate upon.
+ name : str
+ Name hint for this operator.
+
+ Returns
+ -------
+ result : Tensor
+ Concatenated result.
+ """
+ # Convert tensors to expressions.
+ x = [t._expr for t in x]
+ return _wrap_nested(_op.concat(x, dim), name)
+
+
def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
"""Addition with numpy-style broadcasting.
@@ -90,6 +133,34 @@ def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
return _wrap_nested(_op.add(a._expr, b._expr), name)
+def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor:
+ """Subtraction with numpy-style broadcasting.
+
+ Parameters
+ ----------
+ a : Tensor
+ The first input tensor.
+
+ b : Tensor
+ The second input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+
+ Examples
+ --------
+ .. code:: python
+
+ c = subtract(a, b)
+ """
+ return _wrap_nested(_op.subtract(a._expr, b._expr), name)
+
+
def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor:
"""Multiplication with numpy-style broadcasting.
@@ -146,6 +217,28 @@ def divide(a: Tensor, b: Tensor, name: str = "divide") ->
Tensor:
return _wrap_nested(_op.divide(a._expr, b._expr), name)
+def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tensor:
+ """Split a tensor along dim into the specified number of chunks.
+
+ Parameters
+ ----------
+ x : Tensor
+ Input tensor to be split.
+ chunks : int
+ Number of pieces to slice x into.
+ dim : int
+ Which dimension to split x.
+ name : str
+ Name hint for this operation.
+
+ Returns
+ -------
+ result : Tuple[Tensor]
+ A tuple with chunks elements containing slices of x.
+ """
+ return _wrap_nested(_op.split(x._expr, chunks, dim), name)
+
+
def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str =
"matmul") -> Tensor:
"""General matrix multiplication of two tensors, with broadcasting on
batched dimensions.
@@ -569,6 +662,9 @@ def astype(x: Tensor, dtype: str, name: str = "astype") ->
Tensor:
result : Tensor
The casted result.
"""
+ # If trying to cast to same dtype as x, skip casting.
+ if x.dtype == dtype:
+ return x
return _wrap_nested(_op.astype(x._expr, dtype), name)
@@ -598,7 +694,7 @@ def silu(x: Tensor, name: str = "silu") -> Tensor:
return _wrap_nested(_op.nn.silu(x._expr), name)
-def gelu(x: Tensor, name: str = "gelu") -> Tensor:
+def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") ->
Tensor:
r"""Applies the Gaussian Error Linear Units function
.. math::
@@ -611,7 +707,10 @@ def gelu(x: Tensor, name: str = "gelu") -> Tensor:
x : Tensor
The input data
- naem : str
+ approximate : Optional[str]
+ If set to tanh, use an approximation when calculating CDF.
+
+ name : str
Name hint.
Returns
@@ -623,7 +722,20 @@ def gelu(x: Tensor, name: str = "gelu") -> Tensor:
----
The input tensor is required to have float dtype
"""
- return _wrap_nested(_op.nn.gelu(x._expr), name)
+ dtype = x._expr.struct_info.dtype
+ if approximate == "tanh":
+ tanh_const = rx.const(1 + np.tanh(np.sqrt(2 / np.pi)), dtype=dtype)
+ gelu_out = (
+ rx.const(0.5, dtype)
+ * x._expr
+ * (
+ tanh_const
+ * (x._expr + (rx.const(0.044715, dtype) * _op.power(x._expr,
rx.const(3, "int32"))))
+ )
+ )
+ else:
+ gelu_out = _op.nn.gelu(x._expr)
+ return _wrap_nested(gelu_out, name)
def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
@@ -658,10 +770,10 @@ def softmax(x: Tensor, axis: int = -1, name: str =
"softmax") -> Tensor:
def layer_norm(
x: Tensor,
- weight: Tensor,
- bias: Tensor,
- axes: Union[int, List[int]],
- epsilon: float = 1e-5,
+ normalized_shape: Union[int, List[int]],
+ weight: Optional[Tensor] = None,
+ bias: Optional[Tensor] = None,
+ eps: float = 1e-5,
name: str = "layer_norm",
) -> Tensor:
r"""
@@ -685,19 +797,21 @@ def layer_norm(
Parameters
----------
- data : Tensor
+ x : Tensor
Input to which layer_norm will be applied.
- gamma : Tensor
+ normalized_shape: Union[int, List[int]]
+ The shape of axes to normalize. If a single integer
+ is used, it is treated as a singleton list and this
+ module will normalize over the last dimension.
+
+ weight: Tensor
The gamma scale factor.
- beta : Tensor
+ bias: Tensor
The beta offset factor.
- axes : Union[int, List[int]]
- The axes that along which the normalization is applied.
-
- epsilon : float
+ eps: float
Small float added to variance to avoid dividing by zero.
name : str
@@ -708,7 +822,31 @@ def layer_norm(
result : Tensor
The computed result.
"""
- return _wrap_nested(_op.nn.layer_norm(x._expr, weight._expr, bias._expr,
axes, epsilon), name)
+ if isinstance(normalized_shape, int):
+ normalized_shape = [normalized_shape]
+ dim_num = len(normalized_shape)
+ axes = list(range(-dim_num, 0))
+ dtype = x._expr.struct_info.dtype
+
+ if weight is not None:
+ weight = weight._expr
+ else:
+ weight = rx.const(np.ones(normalized_shape), dtype=dtype)
+ if bias is not None:
+ bias = bias._expr
+ else:
+ bias = rx.const(np.zeros(normalized_shape), dtype=dtype)
+
+ return _wrap_nested(
+ _op.nn.layer_norm(
+ x._expr,
+ gamma=weight,
+ beta=bias,
+ axes=axes,
+ epsilon=eps,
+ ),
+ name=name,
+ )
def rms_norm(
@@ -906,6 +1044,39 @@ def zeros(
return _wrap_nested(_op.zeros(shape, dtype), name)
+def pad(
+ x: Tensor,
+ pad: List[int],
+ mode: str = "constant",
+ value: int = 0,
+ name: str = "pad",
+) -> Tensor:
+ """
+ Apply spatial padding to the input tensor.
+
+ Parameters
+ ----------
+ x : Tensor
+ Input tensor to be padded.
+ pad : List[int]
+ List in the format of [before_0, after_0, before_1, after_1, ...]
+ indicating how much to pad each axis of x.
+ mode : str
+ Padding mode to use, constant implies padded elements will use
+ value argument.
+ value : int
+ What to pad with in constant mode.
+ name : str
+ Name hint for this operator.
+
+ Returns
+ -------
+ result : Tensor
+ Padded output tensor.
+ """
+ return _wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value,
pad_mode=mode), name)
+
+
def get_timestep_embedding(
x: Tensor,
embedding_dim: int,
@@ -940,19 +1111,20 @@ def get_timestep_embedding(
result : Tensor
[N x dim] Tensor of positional embeddings.
"""
- timesteps = _op.astype(x._expr, "float32")
+ dtype = get_default_dtype()
+ timesteps = _op.astype(x._expr, dtype)
half_dim = embedding_dim // 2
- exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
- start=0, end=half_dim, dtype="float32"
+ exponent = rx.const(-math.log(max_period), dtype) * _op.arange(
+ start=0, end=half_dim, dtype=dtype
)
- exponent = exponent / (rx.const(half_dim - downscale_freq_shift,
"float32"))
+ exponent = exponent / (rx.const(half_dim - downscale_freq_shift, dtype))
emb = _op.exp(exponent)
emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
# Scale embeddings
if scale != 1:
- emb = rx.const(scale, "float32") * emb
+ emb = rx.const(scale, dtype) * emb
# Concat sine and cosine embeddings.
if flip_sin_to_cos:
@@ -1005,6 +1177,73 @@ def scaled_dot_product_attention(
return _wrap_nested(attn, name)
+def interpolate(
+ x: Tensor,
+ size: Optional[Union[int, Tuple[int]]] = None,
+ scale_factor: Optional[Union[float, Tuple[float]]] = None,
+ mode: str = "nearest",
+ align_corners: Optional[bool] = None,
+ recompute_scale_factor: Optional[bool] = None,
+ antialias: Optional[bool] = None,
+ name: str = "interpolate",
+):
+ """Resize a tensor using the specified mode.
+
+ Parameters
+ ----------
+ x : Tensor
+ Input tensor to be resized.
+ size : Optional[Union[int, Tuple[int]]]
+ Requested output size, only one of size and scale_factor may
+ be specified.
+ scale_factor : Optional[Union[float, Tuple[float]]]
+ Multiplier for spatial size.
+ mode : str
+ Algorithm used for sampling.
+ align_corners : Optional[bool]
+ How to map pixels before and after sampling.
+ recompute_scale_factor : Optional[bool]
+ Recompute the scale_factor for use in interpolation.
+ antialias : Optional[bool]
+ Apply antialiasing to output.
+ name : str
+ Name hint for this operation.
+
+ Returns
+ -------
+ result : Tensor
+ Output tensor with requested shape.
+ """
+ assert recompute_scale_factor is None, "recompute_scale_factor is not
supported."
+ assert antialias is None, "antialias is not supported."
+
+ if size is None:
+ shape = x.shape
+ if isinstance(scale_factor, (list, tuple)):
+ size = tuple(int(shape[i] * scale_factor[i]) for i in range(2,
len(shape)))
+ else:
+ size = tuple(int(shape[i] * scale_factor) for i in range(2,
len(shape)))
+
+ if mode.startswith("nearest"):
+ mode = "nearest_neighbor"
+ elif mode[0:2] == "bi":
+ mode = mode[2:]
+
+ if mode == "nearest_neighbor":
+ coord_trans = "asymmetric"
+ elif align_corners:
+ coord_trans = "align_corners"
+ else:
+ coord_trans = "half_pixel"
+
+ return _wrap_nested(
+ _op.image.resize2d(
+ x._expr, size, layout="NCHW", method=mode,
coordinate_transformation_mode=coord_trans
+ ),
+ name,
+ )
+
+
def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
diff --git a/python/tvm/relax/op/image/image.py
b/python/tvm/relax/op/image/image.py
index 562de5021d..e314e9b49a 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -107,7 +107,7 @@ def resize2d(
if isinstance(size, (int, PrimExpr)):
size = (size, size)
- if isinstance(size, tuple):
+ if isinstance(size, (tuple, list)):
if len(size) == 1:
size = ShapeExpr([size[0], size[0]])
else:
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc
b/src/relax/transform/rewrite_dataflow_reshape.cc
index 2053a56bb7..8345f3e0b7 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -127,8 +127,13 @@ class DataflowReshapeRewriter : public ExprMutator {
return false;
}
auto product = [](Array<PrimExpr> args) -> PrimExpr {
- ICHECK(!args.empty());
- PrimExpr p = args[0];
+ PrimExpr p;
+ if (args.empty()) {
+ // Scalar tensors may be empty indicating a single element.
+ p = 1;
+ } else {
+ p = args[0];
+ }
for (int i = 1, e = args.size(); i < e; ++i) p *= args[i];
return p;
};
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index dba4178f65..86f018383f 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -43,6 +43,22 @@ def test_silu():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_identity():
+ @R.function
+ def forward(
+ x: R.Tensor((3, 3), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object))
= x, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Identity()
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}})
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_linear():
@R.function
def forward(
@@ -93,13 +109,13 @@ def test_layer_norm():
@R.function
def forward(x: R.Tensor((2, 4, 8), dtype="float32"), weight:
R.Tensor((8,), dtype="float32"), bias: R.Tensor((8,), dtype="float32"), _io:
R.Object) -> R.Tuple(R.Tensor((2, 4, 8), dtype="float32"), R.Tuple(R.Object)):
with R.dataflow():
- layer_norm: R.Tensor((2, 4, 8), dtype="float32") =
R.nn.layer_norm(x, weight, bias, axes=[2], epsilon=1.0000000000000001e-05,
center=True, scale=True)
+ layer_norm: R.Tensor((2, 4, 8), dtype="float32") =
R.nn.layer_norm(x, weight, bias, axes=[-1], epsilon=1.0000000000000001e-05,
center=True, scale=True)
gv1: R.Tuple(R.Tensor((2, 4, 8), dtype="float32"),
R.Tuple(R.Object)) = layer_norm, (_io,)
R.output(gv1)
return gv1
# fmt: on
- mod = modules.LayerNorm(8, [2])
+ mod = modules.LayerNorm(8)
tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((2, 4, 8),
"float32")}})
assert_structural_equal(tvm_mod["forward"], forward, True)
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
index 6b59d0419a..255f8ab3ba 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -37,11 +37,12 @@ def test_binary():
z3 = op.matmul(x, y)
z4 = op.maximum(x, y)
z5 = op.minimum(x, y)
- return (z0, z1, z2, z3, z4, z5)
+ z6 = op.subtract(x, y)
+ return (z0, z1, z2, z3, z4, z5, z6)
# fmt: off
@R.function
- def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)):
+ def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32")), R.Tuple(R.Object)):
with R.dataflow():
add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
mul: R.Tensor((10, 10), dtype="float32") = R.multiply(x, y)
@@ -49,7 +50,8 @@ def test_binary():
matmul: R.Tensor((1, 1), dtype="float32") = R.matmul(x, y,
out_dtype="void")
maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
- gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)) = (add, mul, divide,
matmul, maximum, minimum), (_io,)
+ subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y)
+ gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")),
R.Tuple(R.Object)) = (add, mul, divide, matmul, maximum, minimum, subtract),
(_io,)
R.output(gv1)
return gv1
# fmt: on
@@ -58,7 +60,6 @@ def test_binary():
irmodule, _ = m.export_tvm(
spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([10, 1], "float32")}}
)
-
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -70,24 +71,28 @@ def test_manipulate():
z2 = op.reshape(x, [1, 10])
z3 = op.repeat(x, repeats=2, axis=1)
z4 = op.squeeze(x, 0)
- return (z0, z1, z2, z3, z4)
+ z5 = op.unsqueeze(x, 0)
+ z6 = op.concat([x, x], dim=0)
+ return (z0, z1, z2, z3, z4, z5, z6)
# fmt: off
@R.function
- def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2),
dtype="float32"), R.Tensor((5, 2), dtype="float32")), R.Tuple(R.Object)):
+ def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2),
dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2),
dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)):
with R.dataflow():
broadcast_to: R.Tensor((2, 5, 2), dtype="float32") =
R.broadcast_to(x, R.shape([2, 5, 2]))
permute_dims: R.Tensor((2, 5, 1), dtype="float32") =
R.permute_dims(x, axes=[2, 1, 0])
reshape: R.Tensor((1, 10), dtype="float32") = R.reshape(x,
R.shape([1, 10]))
repeat: R.Tensor((1, 10, 2), dtype="float32") = R.repeat(x,
repeats=2, axis=1)
squeeze: R.Tensor((5, 2), dtype="float32") = R.squeeze(x, axis=[0])
- gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"),
R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32")),
R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze),
(_io,)
+ unsqueeze: R.Tensor((1, 1, 5, 2), dtype="float32") =
R.expand_dims(x, axis=0)
+ concat: R.Tensor((2, 5, 2), dtype="float32") = R.concat([x, x],
axis=0)
+ gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"),
R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"),
R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")),
R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze,
unsqueeze, concat), (_io,)
R.output(gv1)
return gv1
# fmt: on
m = Model()
- irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2],
"float32")}})
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2],
"float32")}})
tvm.ir.assert_structural_equal(irmodule["test"], test)
@@ -141,8 +146,10 @@ def test_datatype():
def test_image():
class Model(Module):
def test(self, x: Tensor, weight: Tensor, bias: Tensor):
- conv2d_out = op.conv2d(x, weight, bias)
- return conv2d_out
+ padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
+ conv2d = op.conv2d(padded, weight, bias)
+ interpolate = op.interpolate(x, size=[40, 40])
+ return (conv2d, interpolate)
@R.function
def test(
@@ -150,19 +157,53 @@ def test_image():
weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
bias: R.Tensor((32,), dtype="float32"),
_io: R.Object,
- ) -> R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)):
+ ) -> R.Tuple(
+ R.Tuple(
+ R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tensor((1, 3, 40,
40), dtype="float32")
+ ),
+ R.Tuple(R.Object),
+ ):
with R.dataflow():
- lv1: R.Tensor((1, 32, 30, 30), dtype="float32") = R.nn.conv2d(x,
weight)
+ lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0,
0, 0, 0, 1, 1, 1, 1))
+ lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d(
+ lv0,
+ weight,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias,
R.shape([1, 32, 1, 1]))
- conv2d: R.Tensor((1, 32, 30, 30), dtype="float32") = R.add(lv1,
lv2)
- gv1: R.Tuple(R.Tensor((1, 32, 30, 30), dtype="float32"),
R.Tuple(R.Object)) = conv2d, (
- _io,
+ conv2d: R.Tensor((1, 32, 32, 32), dtype="float32") = R.add(lv1,
lv2)
+ interpolate: R.Tensor((1, 3, 40, 40), dtype="float32") =
R.image.resize2d(
+ x,
+ R.shape([40, 40]),
+ roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
+ layout="NCHW",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="asymmetric",
+ rounding_method="round",
+ cubic_alpha=-0.5,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="void",
)
+ gv1: R.Tuple(
+ R.Tuple(
+ R.Tensor((1, 32, 32, 32), dtype="float32"),
+ R.Tensor((1, 3, 40, 40), dtype="float32"),
+ ),
+ R.Tuple(R.Object),
+ ) = (conv2d, interpolate), (_io,)
R.output(gv1)
return gv1
m = Model()
- irmodule, params = m.export_tvm(
+ irmodule, _ = m.export_tvm(
spec={
"test": {
"x": spec.Tensor([1, 3, 32, 32], "float32"),
@@ -174,6 +215,52 @@ def test_image():
tvm.ir.assert_structural_equal(irmodule["test"], test)
+def test_chunk():
+ class Model(Module):
+ def test(self, x: Tensor):
+ chunk = op.chunk(x, chunks=4)
+ return chunk
+
+ @R.function
+ def test(
+ x: R.Tensor((8,), dtype="float32"), _io: R.Object
+ ) -> R.Tuple(
+ R.Tuple(
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ ),
+ R.Tuple(R.Object),
+ ):
+ with R.dataflow():
+ chunk: R.Tuple(
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ ) = R.split(x, indices_or_sections=4, axis=0)
+ chunk_0: R.Tensor((2,), dtype="float32") = chunk[0]
+ chunk_1: R.Tensor((2,), dtype="float32") = chunk[1]
+ chunk_2: R.Tensor((2,), dtype="float32") = chunk[2]
+ chunk_3: R.Tensor((2,), dtype="float32") = chunk[3]
+ gv1: R.Tuple(
+ R.Tuple(
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ R.Tensor((2,), dtype="float32"),
+ ),
+ R.Tuple(R.Object),
+ ) = (chunk_0, chunk_1, chunk_2, chunk_3), (_io,)
+ R.output(gv1)
+ return gv1
+
+ m = Model()
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([8],
"float32")}})
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
def test_nn():
class Model(Module):
def test(self, x: Tensor, weight: Tensor, bias: Tensor):
diff --git a/tests/python/relax/test_op_image.py
b/tests/python/relax/test_op_image.py
index 251a30139b..3b2db54848 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -244,8 +244,6 @@ def test_resize2d_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.image.resize2d(x1, size=32))
with pytest.raises(TVMError):
bb.normalize(relax.op.image.resize2d(x2, s0))
- with pytest.raises(TVMError):
- relax.op.image.resize2d(x2, [30, 30])
if __name__ == "__main__":
diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
index b02663e5eb..01065bea21 100644
--- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
+++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py
@@ -446,5 +446,59 @@ def test_reshape_detect_nop():
tvm.ir.assert_structural_equal(rewritten, Module)
+def test_reshape_scalar():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,),
dtype="float32"):
+ with R.dataflow():
+ lv1: R.Tensor((1,), dtype="float32") = R.reshape(x, [1])
+ lv2: R.Tensor((1,), dtype="float32") = R.add(lv1, lv1)
+ R.output(lv2)
+ return lv2
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def add(
+ A: T.Buffer((T.int64(1),), "float32"),
+ B: T.Buffer((T.int64(1),), "float32"),
+ T_add: T.Buffer((T.int64(1),), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0 in range(T.int64(1)):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ T.reads(A[v_ax0], B[v_ax0])
+ T.writes(T_add[v_ax0])
+ T_add[v_ax0] = A[v_ax0] + B[v_ax0]
+
+ @T.prim_func(private=True)
+ def reshape(A: T.Buffer((), "float32"), T_reshape:
T.Buffer((T.int64(1),), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0 in range(T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ T.reads(A[()])
+ T.writes(T_reshape[v_ax0])
+ T_reshape[v_ax0] = A[()]
+
+ @R.function
+ def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((1,),
dtype="float32"):
+ cls = Expected
+ with R.dataflow():
+ lv1: R.Tensor((1,), dtype="float32") = R.reshape(x,
R.shape([1]))
+ lv2 = R.call_tir(cls.add, (lv1, lv1), out_sinfo=R.Tensor((1,),
dtype="float32"))
+ R.output(lv2)
+ return lv2
+
+ mod = Module
+ mod = relax.transform.LegalizeOps()(mod)
+ rewritten = relax.transform.RewriteDataflowReshape()(mod)
+ tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()