This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f922f7cca2 [Unity][Frontend][NN] gelu NN module (#15593)
f922f7cca2 is described below

commit f922f7cca29011db22b690d354dfa928d438101b
Author: Lesheng Jin <[email protected]>
AuthorDate: Sat Aug 19 10:13:59 2023 -0700

    [Unity][Frontend][NN] gelu NN module (#15593)
    
    * gelu
    
    * upd
---
 python/tvm/relax/frontend/nn/op.py        | 28 ++++++++++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py |  2 ++
 2 files changed, 30 insertions(+)
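
For reference, the docstring added below documents GeLU as 0.5 * x * (1 + erf(x / sqrt(2))). A minimal numeric check of that formula, using only the Python standard library (illustrative only; this snippet is not part of the commit):

    from math import erf, sqrt

    def gelu_ref(x: float) -> float:
        # GeLU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + erf(x / sqrt(2.0)))

    print(gelu_ref(1.0))   # ~0.8413
    print(gelu_ref(-1.0))  # ~-0.1587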

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index d4821d8185..06d5e883e5 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -455,6 +455,34 @@ def silu(x: Tensor, name: str = "silu") -> Tensor:
     return _wrap_nested(_op.nn.silu(x._expr), name)
 
 
+def gelu(x: Tensor, name: str = "gelu") -> Tensor:
+    r"""Applies the Gaussian Error Linear Units function
+
+    .. math::
+        \text{GeLU}(x) = 0.5 * x * (1 + \text{erf}(x * 0.5**0.5))
+
+    where :math:`erf` is the Gauss Error function.
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data
+
+    name : str
+        Name hint.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype.
+    """
+    return _wrap_nested(_op.nn.gelu(x._expr), name)
+
+
 def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
     r"""Computes softmax.
 
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 41a29fa629..c23ea8a437 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -139,6 +139,7 @@ def test_nn():
     class Model(Module):
         def test(self, x: Tensor, weight: Tensor, bias: Tensor):
             silu_out = op.silu(x)
+            gelu_out = op.gelu(x)
             softmax_out = op.softmax(x, axis=2)
             rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
             rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
@@ -154,6 +155,7 @@ def test_nn():
     ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
         with R.dataflow():
             silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
+            gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
             rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                 x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05

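As the new test lines show, the frontend op is invoked as op.gelu(x) inside an nn.Module method and lowers to R.nn.gelu in the emitted Relax function. A minimal usage sketch, assuming the tvm.relax.frontend.nn import path exercised by the test (the model and method names here are hypothetical):

    from tvm.relax.frontend.nn import Module, Tensor, op

    class GeluModel(Module):  # hypothetical example model
        def forward(self, x: Tensor) -> Tensor:
            # op.gelu wraps the underlying Relax op (R.nn.gelu);
            # the input tensor must have a float dtype.
            return op.gelu(x)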