This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5a67a00bcb [Unity][Frontend] Add Sqrt Op (#17228)
5a67a00bcb is described below
commit 5a67a00bcbb53731bbf53db7801fa16c8c9eb9f2
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Aug 5 21:17:48 2024 +0800
[Unity][Frontend] Add Sqrt Op (#17228)
* Update op.py
* Update test_frontend_nn_op.py
* Update op.py with annotation
* Update core.py (typo in annotation)
---
python/tvm/relax/frontend/nn/core.py | 2 +-
python/tvm/relax/frontend/nn/op.py | 22 ++++++++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 6 ++++--
3 files changed, 27 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 3511c38a2b..21118b1cb8 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -17,7 +17,7 @@
"""The core infra for nn.Module, which includes the following pieces:
- Tensor, a wrapper on top of relax.Expr whose struct_info is a
TensorStructInfo,
providing more convenient access shape and dtype information.
- Tensor is always symbolc and not bound to any concrete values.
+ Tensor is always symbolic and not bound to any concrete values.
- Parameter, a special tensor which could be bound or not bound to concrete
values.
- Module, a container of nn.Parameters and sub nn.Modules.
- Effect, a non-user-facing class that encloses potential side effects, for
example, IO,
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index e1ba4483c7..17a40a8cce 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1486,6 +1486,28 @@ def square(x: Tensor, name: str = "square") -> Tensor:
return wrap_nested(_op.square(x._expr), name)
+def sqrt(x: Tensor, name: str = "sqrt") -> Tensor:
+ """Computes the element-wise square root of the input tensor.
+
+ Parameters
+ ----------
+ x : Tensor
+ The input tensor.
+
+ name : str
+ Name hint.
+
+ Returns
+ -------
+ result : Tensor
+ The computed result.
+
+ Note
+ ----
+ The input tensor is required to have float dtype.
+ """
+ return wrap_nested(_op.sqrt(x._expr), name)
+
+
def get_timestep_embedding(
x: Tensor,
embedding_dim: int,
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index a632a86743..6c32691954 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -31,7 +31,8 @@ def test_unary():
class Model(Module):
def test(self, x: Tensor):
z0 = op.square(x)
- return (x,)
+ z1 = op.sqrt(x)
+ return (z0, z1)
# fmt: off
@R.function
@@ -39,7 +40,8 @@ def test_unary():
R.func_attr({"num_input": 2})
with R.dataflow():
square: R.Tensor((1, 10), dtype="float32") = R.square(x)
- gv1 = (x,), (_io,)
+ sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x)
+ gv1 = (square, sqrt), (_io,)
R.output(gv1)
return gv1
# fmt: on