This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bac72697ac [Unity] nn.Module Op (#15418)
bac72697ac is described below
commit bac72697ac09db2798d4a31a9a9b68d79e0cf200
Author: Lesheng Jin <[email protected]>
AuthorDate: Thu Jul 27 21:59:00 2023 -0700
[Unity] nn.Module Op (#15418)
This PR introduces the `_TensorOp` class and a collection of useful ops, helping
users build their own models in a torch-like programming style. Each op
invokes the `_wrap_nested` function to automatically emit a relax op with the
BlockBuilder.
Example:
```python
class Model(Module):
def test(self, x: Tensor, y: Tensor):
z0 = op.add(x, y)
z1 = x * y
return (z0, z1)
```
---
python/tvm/relax/frontend/nn/_tensor_op.py | 59 ++-
python/tvm/relax/frontend/nn/op.py | 671 ++++++++++++++++++++++++++
python/tvm/relax/frontend/nn/spec.py | 2 +-
tests/python/relax/test_frontend_nn_op.py | 258 ++++++++++
tests/python/relax/test_frontend_nn_tensor.py | 159 ++++++
5 files changed, 1146 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py
b/python/tvm/relax/frontend/nn/_tensor_op.py
index a5e4f9b0cb..a653c9fa29 100644
--- a/python/tvm/relax/frontend/nn/_tensor_op.py
+++ b/python/tvm/relax/frontend/nn/_tensor_op.py
@@ -15,7 +15,62 @@
# specific language governing permissions and limitations
# under the License.
"""Adding member operators to nn.Tensor."""
+from typing import Optional
+from tvm import tir
-class _TensorOp: # pylint: disable=too-few-public-methods
- pass
+
def _op():
    """Lazily import and return the ``tvm.relax.frontend.nn.op`` module.

    The import is deferred to call time because ``op`` itself depends on this
    package, so a top-level import would create a circular import.
    """
    from tvm.relax.frontend.nn import op  # pylint: disable=import-outside-toplevel

    return op
+
+
def _convert_scalar(scalar, ref) -> "Tensor":
    """Promote a Python/TIR scalar to a Tensor matching ``ref``'s dtype.

    Tensors pass through untouched; values that are neither Tensors nor
    recognized scalars are returned unchanged.
    """
    from .core import Tensor  # pylint: disable=import-outside-toplevel

    if isinstance(scalar, Tensor):
        return scalar
    if isinstance(scalar, (tir.FloatImm, tir.IntImm)):
        # Unwrap the TIR immediate to a plain Python number first.
        scalar = scalar.value
    if isinstance(scalar, (int, float)):
        return Tensor.from_scalar(scalar, dtype=ref.dtype)
    return scalar
+
+
class _TensorOp:
    """Mixin providing operator overloads and tensor-style methods for nn.Tensor.

    Each method dispatches to the matching function in
    ``tvm.relax.frontend.nn.op`` (resolved lazily through ``_op`` to avoid a
    circular import), so the corresponding relax op is emitted through the
    active BlockBuilder.  Right-hand-side Python scalars and TIR immediates
    are promoted to Tensors of this tensor's dtype via ``_convert_scalar``.
    """

    def __add__(self, other):
        other = _convert_scalar(other, self)
        return _op().add(self, other)

    def __radd__(self, other):
        # Addition is commutative, so the reflected form reuses the same op
        # with unchanged operand order.
        other = _convert_scalar(other, self)
        return _op().add(self, other)

    def __mul__(self, other):
        other = _convert_scalar(other, self)
        return _op().multiply(self, other)

    def __truediv__(self, other):
        other = _convert_scalar(other, self)
        return _op().divide(self, other)

    def astype(self, dtype):
        return _op().astype(self, dtype)

    def maximum(self, other):
        other = _convert_scalar(other, self)
        return _op().maximum(self, other)

    def minimum(self, other):
        other = _convert_scalar(other, self)
        return _op().minimum(self, other)

    def reshape(self, shape):
        return _op().reshape(self, shape)

    def permute_dims(self, axes):
        return _op().permute_dims(self, axes)

    def repeat(self, repeats: int, axis: Optional[int] = None):
        return _op().repeat(self, repeats, axis)
diff --git a/python/tvm/relax/frontend/nn/op.py
b/python/tvm/relax/frontend/nn/op.py
new file mode 100644
index 0000000000..74a337d8a3
--- /dev/null
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -0,0 +1,671 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=too-many-lines,invalid-name,protected-access
+"""nn.Tensor operators."""
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+from tvm import tir as _tir
+
+from ... import expr as rx
+from ... import op as _op
+from ...block_builder import BlockBuilder
+from ...struct_info import TensorStructInfo, TupleStructInfo
+from .core import Tensor
+
+IntExpr = Union[int, _tir.PrimExpr]
+
+
def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Tuple[Tensor]]:
    """Wrap the given relax.Expr, emit it using the current BlockBuilder,
    and automatically handle nested cases if the expr represents a Tuple.

    Parameters
    ----------
    expr : relax.Expr
        The Expr to be wrapped.

    name : str
        Name hint.

    Returns
    -------
    result : Union[Tensor, Tuple[Tensor]]
        The computed result.

    Raises
    ------
    TypeError
        If the emitted expression is neither a Tensor nor a Tuple.
    """
    expr = BlockBuilder.current().emit(expr, name)
    if isinstance(expr.struct_info_, TensorStructInfo):
        return Tensor(_expr=expr)
    if isinstance(expr.struct_info_, TupleStructInfo):
        # ``fields`` is a list of StructInfo objects; iterate over its
        # *length* — ``range()`` over the list itself raises TypeError.
        return tuple(
            _wrap_nested(
                rx.TupleGetItem(expr, i),
                name=f"{name}.{i}",
            )
            for i in range(len(expr.struct_info_.fields))
        )
    raise TypeError(f"Unsupported return type: {expr.struct_info_}")
+
+
def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor:
    """Addition with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = add(a, b)
    """
    summed = _op.add(a._expr, b._expr)
    return _wrap_nested(summed, name)
+
+
def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor:
    """Multiplication with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = multiply(a, b)
    """
    product = _op.multiply(a._expr, b._expr)
    return _wrap_nested(product, name)
+
+
def divide(a: Tensor, b: Tensor, name: str = "divide") -> Tensor:
    """Division with numpy-style broadcasting.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = divide(a, b)
    """
    quotient = _op.divide(a._expr, b._expr)
    return _wrap_nested(quotient, name)
+
+
def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor:
    """General matrix multiplication of two tensors, with broadcasting on
    batched dimensions.

    The semantics and output shape deduction rule is specified as
    https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.

    Parameters
    ----------
    a : Tensor
        The first input tensor.

    b : Tensor
        The second input tensor.

    out_dtype : Optional[Union[str, DataType]]
        The data type of the matmul result.
        When it is not specified, the output dtype will be the same as the
        input dtype.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = matmul(a, b)
    """
    return _wrap_nested(_op.matmul(a._expr, b._expr, out_dtype=out_dtype), name)
+
+
def maximum(x1: Tensor, x2: Tensor, name: str = "maximum") -> Tensor:
    """Element-wise maximum.

    Parameters
    ----------
    x1 : Tensor
        The first input tensor.

    x2 : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = maximum(a, b)
    """
    return _wrap_nested(_op.maximum(x1._expr, x2._expr), name)
+
+
def minimum(x1: Tensor, x2: Tensor, name: str = "minimum") -> Tensor:
    """Element-wise minimum.

    Parameters
    ----------
    x1 : Tensor
        The first input tensor.

    x2 : Tensor
        The second input tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Examples
    --------
    .. code:: python

        c = minimum(a, b)
    """
    return _wrap_nested(_op.minimum(x1._expr, x2._expr), name)
+
+
def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to") -> Tensor:
    """Broadcasts a tensor to a specified shape.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    shape : Sequence[IntExpr]
        The target shape.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The broadcasted tensor.
    """
    return _wrap_nested(_op.broadcast_to(x._expr, shape), name)
+
+
def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permute_dims") -> Tensor:
    """Permutes the dimensions of an array.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axes : Optional[List[int]]
        The target axes order; when not specified, the axes are reversed.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The transposed result.
    """
    return _wrap_nested(_op.permute_dims(x._expr, axes=axes), name)
+
+
def reshape(x: Tensor, shape: Sequence[IntExpr], name="reshape") -> Tensor:
    """Reshape the input array.

    ``-1`` infers the dimension of the output shape by using the remainder of
    the input dimensions, keeping the size of the new array the same as that
    of the input array.  At most one dimension of shape can be -1.

    .. code-block:: python

        x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4)
        x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8)
        x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,)

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    shape : Sequence[IntExpr]
        The new shape. Should be compatible with the original shape.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The reshaped result.

    Note
    ----
    The ``-1`` inference is only performed at compile-time.
    That is to say, in any case the dimension length of ``-1`` cannot be
    inferred in compile-time, an error will be thrown.
    """
    return _wrap_nested(_op.reshape(x._expr, shape), name)
+
+
def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") -> Tensor:
    """Repeats elements of an array.

    Parameters
    ----------
    x : Tensor
        The input tensor.

    repeats : int
        The number of repetitions.

    axis : Optional[int]
        The axis along which to repeat values. Negative numbers are
        interpreted counting from the back. By default, uses the flattened
        input array and returns a flat output array.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The computed result.

    Examples
    --------
    .. code-block:: python

        np_x = numpy.array([[1, 2], [3, 4]])
        x = Tensor.from_const(np_x)
        lv1 = repeat(x, repeats=2)  # lv1 == [1, 1, 2, 2, 3, 3, 4, 4]
        lv2 = repeat(x, repeats=2, axis=1)  # lv2 == [[1., 1., 2., 2.],
                                            #         [3., 3., 4., 4.]]
    """
    return _wrap_nested(_op.repeat(x._expr, repeats, axis), name)
+
+
def squeeze(x: Tensor, axis: Union[int, List[int]] = -1, name: str = "squeeze") -> Tensor:
    """Squeeze axes in the array.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axis : Union[int, List[int]]
        The axis (or set of axes) to remove; defaults to -1 (the last axis).
        If any specified axis has a dimension that does not equal 1, it is an
        error.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The squeezed result.
    """
    return _wrap_nested(_op.squeeze(x._expr, axis), name)
+
+
def take(x: Tensor, indices: Tensor, axis: Optional[int] = None, name="take") -> Tensor:
    """Take elements from a tensor along an axis.

    Its semantic is mostly similar to `numpy.take`
    (https://numpy.org/doc/stable/reference/generated/numpy.take.html),
    which can cover `torch.take`
    (https://pytorch.org/docs/stable/generated/torch.take.html) and
    `onnx.gather`
    (https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13).

    Parameters
    ----------
    x : Tensor
        The source tensor.

    indices : Tensor
        The indices of the values to extract.

    axis : Optional[int]
        The axis over which to select values.
        If it is None, the input tensor is required to be one-dimensional.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The taken result.
    """
    return _wrap_nested(_op.take(x._expr, indices._expr, axis), name)
+
+
def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor:
    """Cast the input tensor to the given data type.

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    dtype : str
        The target data type.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The casted result.
    """
    casted = _op.astype(x._expr, dtype)
    return _wrap_nested(casted, name)
+
+
def silu(x: Tensor, name: str = "silu") -> Tensor:
    r"""Sigmoid Linear Unit function

    .. math::
        \text{SiLU}(x) = x * \text{sigmoid}(x)

    Parameters
    ----------
    x : Tensor
        The input data.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    return _wrap_nested(_op.nn.silu(x._expr), name)
+
+
def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
    r"""Computes softmax.

    .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    Parameters
    ----------
    x : Tensor
        The input data to the operator.

    axis : int
        The axis to sum over when computing softmax.
        If not specified, it is by default the last axis of the input tensor.
        Supports negative indexing.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.

    Note
    ----
    The input tensor is required to have float dtype
    """
    return _wrap_nested(_op.nn.softmax(x._expr, axis), name)
+
+
def rms_norm(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    axes: Union[int, List[int]],
    epsilon: float = 1e-5,
    name: str = "rms_norm",
) -> Tensor:
    r"""
    Root mean square normalization (Biao Zhang and et al., 2019).
    Applies root mean square normalization to the n-dimensional input array.
    This operator takes an n-dimensional input array and normalizes
    the input using the given axis:

    .. math::

        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias

    Parameters
    ----------
    x : Tensor
        Input to which rms_norm will be applied.

    weight : Tensor
        The scale factor.

    bias : Optional[Tensor]
        Optional offset factor; when None, a zero offset of the same shape
        and dtype as ``weight`` is used.

    axes : Union[int, List[int]]
        The axes along which the normalization is applied.

    epsilon : float
        Small float added to square mean to avoid dividing by zero.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The computed result.
    """
    if bias is None:
        # No offset supplied: substitute an all-zero bias matching the weight.
        bias = _op.zeros(weight.shape, dtype=weight.dtype)
    else:
        bias = bias._expr
    return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, bias, axes, epsilon), name)
+
+
def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
    """Return the upper triangular part of a matrix or a batch of matrices.

    Parameters
    ----------
    x : Tensor
        The tensor that triu will be applied to.
        It is required to have at least two dimensions.

    diagonal : int
        The index indicating the diagonal below which to zero elements.
        If diagonal = 0, the diagonal is the main diagonal.
        If diagonal < 0, the diagonal is below the main diagonal.
        If diagonal > 0, the diagonal is above the main diagonal.

    name : str
        Name hint.

    Returns
    -------
    ret : Tensor
        The result tensor.
    """
    return _wrap_nested(_op.triu(x._expr, diagonal), name)
+
+
def full(
    shape: Sequence[IntExpr],
    fill_value: Union[Tensor, _tir.FloatImm, _tir.IntImm, int, float],
    dtype: str = "float32",
    name: str = "full",
) -> Tensor:
    """Fill array with scalar value.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    fill_value : Union[Tensor, _tir.FloatImm, _tir.IntImm, int, float]
        The value to fill. Tensor inputs must be scalar tensors; Python and
        TIR scalars are converted to constants of ``dtype``.

    dtype : str
        The data type of the created tensor; defaults to "float32".

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    # Local import keeps module load order intact (relax depends on this package).
    from tvm import relax  # pylint: disable=import-outside-toplevel

    if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
        fill_value = relax.const(fill_value.value, dtype=dtype)
    elif isinstance(fill_value, (int, float)):
        fill_value = relax.const(fill_value, dtype=dtype)
    else:
        fill_value = fill_value._expr
    return _wrap_nested(_op.full(shape, fill_value, dtype), name)
+
+
def zeros(
    shape: Sequence[IntExpr],
    dtype: str = "float32",
    name: str = "zeros",
) -> Tensor:
    """Construct a tensor filled with zeros of the given shape and dtype.

    Parameters
    ----------
    shape : Sequence[IntExpr]
        The shape of the created tensor.

    dtype : str
        The data type of the created tensor.

    name : str
        Name hint.

    Returns
    -------
    result : Tensor
        The result tensor.
    """
    zero_expr = _op.zeros(shape, dtype)
    return _wrap_nested(zero_expr, name)
+
+
def tensor_expr_op(
    tensor_expr_func: Callable,
    name_hint: str,
    args: List[Union[Tensor, _tir.Var]],
    *,
    attrs: Optional[Dict[str, Any]] = None,
):
    """Build the given tensor_expr_func with te.

    Parameters
    ----------
    tensor_expr_func : Callable
        A function that returns a te tensor or a list of tensors.

    name_hint : str
        Name hint.

    args : List[Union[Tensor, _tir.Var]]
        Arguments passed to the function.

    attrs : Optional[Dict[str, Any]]
        A dict of attributes to apply to the function.

    Returns
    -------
    result : Tensor
        The result tensor.
    """

    def _convert(arg):
        # Unwrap nn.Tensor into its underlying relax expression; TIR vars
        # (and anything else) pass through unchanged.
        if isinstance(arg, Tensor):
            return arg._expr  # pylint: disable=protected-access
        return arg

    return _wrap_nested(
        BlockBuilder.current().emit_te(
            tensor_expr_func,
            *[_convert(arg) for arg in args],
            primfunc_name_hint=name_hint,
            primfunc_attrs=attrs,
        ),
        name=name_hint,
    )
diff --git a/python/tvm/relax/frontend/nn/spec.py
b/python/tvm/relax/frontend/nn/spec.py
index 3a9be83a51..73d7c80638 100644
--- a/python/tvm/relax/frontend/nn/spec.py
+++ b/python/tvm/relax/frontend/nn/spec.py
@@ -106,7 +106,7 @@ class MethodSpec:
for arg_name in arg_names:
arg_spec = spec[arg_name]
if arg_spec is Int or arg_spec is int:
- arg_spec = arg_spec()
+ arg_spec = Int()
elif isinstance(arg_spec, str) and arg_spec == "int":
arg_spec = Int()
elif isinstance(arg_spec, (Int, Tensor)):
diff --git a/tests/python/relax/test_frontend_nn_op.py
b/tests/python/relax/test_frontend_nn_op.py
new file mode 100644
index 0000000000..6668f02743
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -0,0 +1,258 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, te, tir
+from tvm.relax.frontend.nn import Tensor, Module, spec, op
+from tvm.script import relax as R
+from tvm.script import tir as T
+from tvm.script import ir as I
+
+
+def test_binary():
+ class Model(Module):
+ def test(self, x: Tensor, y: Tensor):
+ z0 = op.add(x, y)
+ z1 = op.multiply(x, y)
+ z2 = op.divide(x, y)
+ z3 = op.matmul(x, y)
+ z4 = op.maximum(x, y)
+ z5 = op.minimum(x, y)
+ return (z0, z1, z2, z3, z4, z5)
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10),
dtype="float32"), R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)):
+ with R.dataflow():
+ add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
+ mul: R.Tensor((10, 10), dtype="float32") = R.multiply(x, y)
+ divide: R.Tensor((10, 10), dtype="float32") = R.divide(x, y)
+ matmul: R.Tensor((1, 1), dtype="float32") = R.matmul(x, y,
out_dtype="void")
+ maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
+ minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
+ gv1: R.Tuple(R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((1, 1), dtype="float32"), R.Tensor((10, 10), dtype="float32"),
R.Tensor((10, 10), dtype="float32")), R.Tuple(R.Object)) = (add, mul, divide,
matmul, maximum, minimum), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([10, 1], "float32")}}
+ )
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_manipulate():
+ class Model(Module):
+ def test(self, x: Tensor):
+ z0 = op.broadcast_to(x, [2, 5, 2])
+ z1 = op.permute_dims(x, [2, 1, 0])
+ z2 = op.reshape(x, [1, 10])
+ z3 = op.repeat(x, repeats=2, axis=1)
+ z4 = op.squeeze(x, 0)
+ return (z0, z1, z2, z3, z4)
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2),
dtype="float32"), R.Tensor((5, 2), dtype="float32")), R.Tuple(R.Object)):
+ with R.dataflow():
+ broadcast_to: R.Tensor((2, 5, 2), dtype="float32") =
R.broadcast_to(x, R.shape([2, 5, 2]))
+ permute_dims: R.Tensor((2, 5, 1), dtype="float32") =
R.permute_dims(x, axes=[2, 1, 0])
+ reshape: R.Tensor((1, 10), dtype="float32") = R.reshape(x,
R.shape([1, 10]))
+ repeat: R.Tensor((1, 10, 2), dtype="float32") = R.repeat(x,
repeats=2, axis=1)
+ squeeze: R.Tensor((5, 2), dtype="float32") = R.squeeze(x, axis=[0])
+ gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"),
R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32")),
R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze),
(_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 5, 2],
"float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_index():
+ class Model(Module):
+ def test(self, x: Tensor, y: Tensor):
+ z0 = op.take(x, y, axis=2)
+ return z0
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((2, 1, 10), dtype="float32"), y: R.Tensor((5,),
dtype="int32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 5), dtype="float32"),
R.Tuple(R.Object)):
+ with R.dataflow():
+ take: R.Tensor((2, 1, 5), dtype="float32") = R.take(x, y, axis=2)
+ gv1: R.Tuple(R.Tensor((2, 1, 5), dtype="float32"),
R.Tuple(R.Object)) = take, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(
+ spec={"test": {"x": spec.Tensor([2, 1, 10], "float32"), "y":
spec.Tensor([5], "int32")}}
+ )
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_datatype():
+ class Model(Module):
+ def test(self, x: Tensor):
+ z0 = op.astype(x, "float16")
+ return z0
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)):
+ with R.dataflow():
+ astype: R.Tensor((2, 1, 10), dtype="float16") = R.astype(x,
dtype="float16")
+ gv1: R.Tuple(R.Tensor((2, 1, 10), dtype="float16"),
R.Tuple(R.Object)) = astype, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1,
10], "float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_nn():
+ class Model(Module):
+ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
+ silu_out = op.silu(x)
+ softmax_out = op.softmax(x, axis=2)
+ rms_norm_out = op.rms_norm(x, weight, bias, axes=[-2, -1])
+ rms_norm_with_bias_out = op.rms_norm(x, weight, bias, axes=[-2,
-1])
+ return x
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((2, 3, 4, 5), dtype="float32"), weight: R.Tensor((4,
5), dtype="float32"), bias: R.Tensor((4, 5), dtype="float32"), _io: R.Object)
-> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
+ softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x,
axis=2)
+ rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") =
R.nn.rms_norm(x, weight, bias, axes=[-2, -1], epsilon=1.0000000000000001e-05)
+ rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") =
R.nn.rms_norm(x, weight, bias, axes=[-2, -1], epsilon=1.0000000000000001e-05)
+ gv1: R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"),
R.Tuple(R.Object)) = x, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(
+ spec={
+ "test": {
+ "x": spec.Tensor([2, 3, 4, 5], "float32"),
+ "weight": spec.Tensor([4, 5], "float32"),
+ "bias": spec.Tensor([4, 5], "float32"),
+ }
+ }
+ )
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_create():
+ class Model(Module):
+ def test(self, x: Tensor):
+ triu_out = op.triu(x)
+ full_with_scalar_out = op.full([10, 10], fill_value=10)
+ full_with_FloatImm_out = op.full(
+ [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10)
+ )
+ full_with_Tensor_out = op.full(
+ [10, 10], fill_value=Tensor.from_scalar(10, dtype="float32")
+ )
+ zeros_out = op.zeros([10, 10])
+ zeros_fp16_out = op.zeros([10, 10], dtype="float16")
+ return x
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ triu: R.Tensor((10, 10), dtype="float32") = R.triu(x, k=0)
+ full: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10,
10]), R.const(10, "float32"), dtype="float32")
+ full1: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10,
10]), R.const(10, "float32"), dtype="float32")
+ full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10,
10]), R.const(10, "float32"), dtype="float32")
+ zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10,
10]), dtype="float32")
+ zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10,
10]), dtype="float16")
+ gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tuple(R.Object)) = x, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_tensor_expr_op():
+ class Model(Module):
+ def test(self, x: Tensor):
+ tensor_expr_op_out = op.tensor_expr_op(
+ tensor_expr_func=lambda x: x + 1, name_hint="add_one", args=[x]
+ )
+ return x
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add:
T.Buffer((T.int64(10), T.int64(10)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(10), T.int64(10)):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1])
+ T.writes(T_add[v_ax0, v_ax1])
+ T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T.float32(1)
+
+ @R.function
+ def _initialize_effect() -> R.Tuple(R.Object):
+ with R.dataflow():
+ _io: R.Object = R.null_value()
+ lv: R.Tuple(R.Object) = (_io,)
+ gv: R.Tuple(R.Object) = lv
+ R.output(gv)
+ return gv
+
+ @R.function
+ def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
+ cls = Expected
+ with R.dataflow():
+ lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10,
10), dtype="float32"))
+ add_one1: R.Tensor((10, 10), dtype="float32") = lv1
+ gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"),
R.Tuple(R.Object)) = x, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([10, 10],
"float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_tensor.py
b/tests/python/relax/test_frontend_nn_tensor.py
new file mode 100644
index 0000000000..63d756e637
--- /dev/null
+++ b/tests/python/relax/test_frontend_nn_tensor.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.frontend.nn import Tensor, Module, spec
+from tvm.script import relax as R
+
+import numpy as np
+
+
def test_tensor_from_numpy():
    """A Tensor built from a numpy array reports the expected shape/rank/dtype."""
    arr = np.random.rand(1, 10)
    tensor = Tensor.from_const(arr)
    assert tensor.shape == [1, 10]
    assert tensor.ndim == 2
    assert tensor.dtype == "float32"
    assert repr(tensor) == 'Tensor([1, 10], "float32")'
+
+
def test_tensor_from_scalar():
    """A Tensor built from a Python scalar is rank-0 with the requested dtype."""
    tensor = Tensor.from_scalar(123.321, dtype="float16")
    assert tensor.shape == []
    assert tensor.ndim == 0
    assert tensor.dtype == "float16"
    assert repr(tensor) == 'Tensor([], "float16")'
+
+
+def test_tensor_op_binary_tensor_tensor():
+ class Model(Module):
+ def test(self, x: Tensor, y: Tensor):
+ z0 = x + y
+ z1 = x * y
+ z2 = x / y
+ z3 = x.maximum(y)
+ z4 = x.minimum(y)
+ return (z0, z1, z2, z3, z4)
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((2, 1),
dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((2, 10),
dtype="float32"), R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10),
dtype="float32"), R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10),
dtype="float32")), R.Tuple(R.Object)):
+ with R.dataflow():
+ add: R.Tensor((2, 10), dtype="float32") = R.add(x, y)
+ mul: R.Tensor((2, 10), dtype="float32") = R.multiply(x, y)
+ divide: R.Tensor((2, 10), dtype="float32") = R.divide(x, y)
+ maximum: R.Tensor((2, 10), dtype="float32") = R.maximum(x, y)
+ minimum: R.Tensor((2, 10), dtype="float32") = R.minimum(x, y)
+ gv1: R.Tuple(R.Tuple(R.Tensor((2, 10), dtype="float32"),
R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10), dtype="float32"),
R.Tensor((2, 10), dtype="float32"), R.Tensor((2, 10), dtype="float32")),
R.Tuple(R.Object)) = (add, mul, divide, maximum, minimum), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, _ = m.export_tvm(
+ spec={"test": {"x": spec.Tensor([1, 10], "float32"), "y":
spec.Tensor([2, 1], "float32")}}
+ )
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_tensor_op_binary_tensor_saclar():
+ class Model(Module):
+ def test(self, x: Tensor):
+ y = 10
+ z0 = x + y
+ z1 = y + x
+ z2 = x * y
+ z3 = x / y
+ z4 = x.maximum(y)
+ z5 = x.minimum(y)
+ return (z0, z1, z2, z3, z4, z5)
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10),
dtype="float32")), R.Tuple(R.Object)):
+ with R.dataflow():
+ add: R.Tensor((1, 10), dtype="float32") = R.add(x, R.const(10,
"float32"))
+ add1: R.Tensor((1, 10), dtype="float32") = R.add(x, R.const(10,
"float32"))
+ mul: R.Tensor((1, 10), dtype="float32") = R.multiply(x,
R.const(10, "float32"))
+ divide: R.Tensor((1, 10), dtype="float32") = R.divide(x,
R.const(10, "float32"))
+ maximum: R.Tensor((1, 10), dtype="float32") = R.maximum(x,
R.const(10, "float32"))
+ minimum: R.Tensor((1, 10), dtype="float32") = R.minimum(x,
R.const(10, "float32"))
+ gv1: R.Tuple(R.Tuple(R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10), dtype="float32"),
R.Tensor((1, 10), dtype="float32")), R.Tuple(R.Object)) = (add, add1, mul,
divide, maximum, minimum), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_tensor_op_datatype():
+ class Model(Module):
+ def test(self, x: Tensor):
+ z0 = x.astype(dtype="float16")
+ return z0
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tensor((1, 10), dtype="float16"), R.Tuple(R.Object)):
+ with R.dataflow():
+ astype: R.Tensor((1, 10), dtype="float16") = R.astype(x,
dtype="float16")
+ gv1: R.Tuple(R.Tensor((1, 10), dtype="float16"),
R.Tuple(R.Object)) = astype, (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([1, 10],
"float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+def test_tensor_op_manipulate():
+ class Model(Module):
+ def test(self, x: Tensor):
+ z0 = x.reshape([2, 5, 2])
+ z1 = x.permute_dims([2, 1, 0])
+ z2 = x.repeat(2, axis=1)
+ return (z0, z1, z2)
+
+ # fmt: off
+ @R.function
+ def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) ->
R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((10, 1, 2),
dtype="float32"), R.Tensor((2, 2, 10), dtype="float32")), R.Tuple(R.Object)):
+ with R.dataflow():
+ reshape: R.Tensor((2, 5, 2), dtype="float32") = R.reshape(x,
R.shape([2, 5, 2]))
+ permute_dims: R.Tensor((10, 1, 2), dtype="float32") =
R.permute_dims(x, axes=[2, 1, 0])
+ repeat: R.Tensor((2, 2, 10), dtype="float32") = R.repeat(x,
repeats=2, axis=1)
+ gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"),
R.Tensor((10, 1, 2), dtype="float32"), R.Tensor((2, 2, 10), dtype="float32")),
R.Tuple(R.Object)) = (reshape, permute_dims, repeat), (_io,)
+ R.output(gv1)
+ return gv1
+ # fmt: on
+
+ m = Model()
+ irmodule, params = m.export_tvm(spec={"test": {"x": spec.Tensor([2, 1,
10], "float32")}})
+
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()