This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e4f16cb9e7 [Unity][Relax] gelu-tanh operator (#14814)
e4f16cb9e7 is described below

commit e4f16cb9e7c2a33010b71b716ed2fba9c6b42e86
Author: Zihao Ye <[email protected]>
AuthorDate: Thu May 18 06:47:52 2023 -0700

    [Unity][Relax] gelu-tanh operator (#14814)
---
 python/tvm/relax/backend/contrib/cutlass.py        |   2 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  14 +-
 python/tvm/relax/op/nn/nn.py                       |  37 ++++-
 python/tvm/relax/transform/legalize_ops/nn.py      |  16 +++
 python/tvm/relax/transform/transform.py            |   2 +-
 src/relax/op/nn/nn.cc                              |   3 +
 src/relax/op/nn/nn.h                               |   3 +
 src/relax/transform/combine_parallel_matmul.cc     |   5 +-
 tests/python/relax/test_frontend_dynamo.py         |  53 +++++++
 .../python/relax/test_transform_legalize_ops_nn.py | 153 +++++++++++++++++++++
 10 files changed, 278 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 19fc2a39ea..72b8773977 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -160,6 +160,8 @@ def _check_matmul(context: PatternCheckContext) -> bool:
 def _get_activation_from_name(pattern_name):
     if "_relu" in pattern_name:
         return "relax.nn.relu"
+    elif "_gelu_tanh" in pattern_name:
+        return "relax.nn.gelu_tanh"
     elif "_gelu" in pattern_name:
         return "relax.nn.gelu"
     elif "_silu" in pattern_name:
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a29070a325..c5d65e2f0d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -234,6 +234,18 @@ class TorchFXImporter:
             )
         return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
 
+    def _gelu(self, node: fx.node.Node) -> relax.Expr:
+        if "approximate" not in node.kwargs:
+            approximate = "none"
+        else:
+            approximate = node.kwargs["approximate"]
+        if approximate == "none":
+            return self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]]))
+        elif approximate == "tanh":
+            return self.block_builder.emit(relax.op.nn.gelu_tanh(self.env[node.args[0]]))
+        else:
+            raise KeyError("Unrecognized approximate algorithm for gelu: {}.".format(approximate))
+
     ########## Compare ##########
 
     def _lt(self, node: fx.node.Node) -> relax.Expr:
@@ -1180,7 +1192,7 @@ class TorchFXImporter:
             "dropout": lambda node: self.env[node.args[0]],
             "clamp": self._clamp,
             "relu": lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
-            "gelu": lambda node: 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+            "gelu": self._gelu,
             "silu": lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
             "tanh": lambda node: 
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             "interpolate": self._interpolate,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index fb5e0736ff..601a0c8439 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -532,10 +532,10 @@ def adaptive_avg_pool2d(
 
 
 def relu(data: Expr) -> Expr:
-    """Rectified linear unit.
+    r"""Rectified linear unit.
 
     .. math::
-        text{ReLU}(x) = max(x, 0)
+        \text{ReLU}(x) = \max(x, 0)
 
     Parameters
     ----------
@@ -551,10 +551,10 @@ def relu(data: Expr) -> Expr:
 
 
 def gelu(data: Expr) -> Expr:
-    """Gaussian Error Linear Units function
+    r"""Gaussian Error Linear Units function
 
     .. math::
-        text{GeLU}(x) = 0.5 * x * (1 + erf(x * 0.5**0.5))
+        \text{GeLU}(x) = 0.5 * x * (1 + \text{erf}(x * 0.5**0.5))
 
     where :math:`erf` is the Gauss Error function.
 
@@ -575,11 +575,34 @@ def gelu(data: Expr) -> Expr:
     return _ffi_api.gelu(data)  # type: ignore
 
 
+def gelu_tanh(data: Expr) -> Expr:
+    r"""Gaussian Error Linear Units function with tanh approximation
+
+    .. math::
+        \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt{2/\pi} * (x + 0.044715 * x^3)))
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return _ffi_api.gelu_tanh(data)  # type: ignore
+
+
 def silu(data: Expr) -> Expr:
-    """Sigmoid Linear Unit function
+    r"""Sigmoid Linear Unit function
 
     .. math::
-        text{SiLU}(x) = x * sigmoid(x)
+        \text{SiLU}(x) = x * \text{sigmoid}(x)
 
     Parameters
     ----------
@@ -601,7 +624,7 @@ def silu(data: Expr) -> Expr:
 def softmax(data: Expr, axis: int = -1) -> Expr:
     r"""Computes softmax.
 
-    .. math:: text{softmax}(x)_i = frac{exp(x_i)}{\sum_j exp(x_j)}
+    .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
 
     Parameters
     ----------
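
A minimal usage sketch for the new binding, built with the BlockBuilder API
(shapes here are illustrative):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3), "float32"))
    with bb.function("main", [x]):
        with bb.dataflow():
            # relax.op.nn.gelu_tanh is the operator added above.
            y = bb.emit(relax.op.nn.gelu_tanh(x))
            out = bb.emit_output(y)
        bb.emit_func_output(out)
    mod = bb.get()
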
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 9c98682e32..9eea40b3fb 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name,unused-argument
 """Default legalization function for neural network operators."""
 import logging
+import math
 
 from tvm import topi, tir, te
 from ...block_builder import BlockBuilder
@@ -233,6 +234,21 @@ def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(te_gelu, call.args[0], primfunc_name_hint="gelu")
 
 
+@register_legalize("relax.nn.gelu_tanh")
+def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr:
+    def te_gelu_tanh(x: te.Tensor):
+        dtype = x.dtype
+        # GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+        return (
+            tir.const(0.5, dtype)
+            * x
+            * (
+                tir.const(1.0, dtype)
+                + topi.tanh(
+                    tir.const(math.sqrt(2.0 / math.pi), dtype)
+                    * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+                )
+            )
+        )
+
+    return bb.call_te(te_gelu_tanh, call.args[0], primfunc_name_hint="gelu_tanh")
+
+
 @register_legalize("relax.nn.silu")
 def _nn_silu(bb: BlockBuilder, call: Call) -> Expr:
     def te_silu(x: te.Tensor):
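
As an informal numeric check (not part of the commit), the tanh form should
track the exact erf-based GELU closely over typical activation values:

    import math

    def gelu_exact(x):
        # Exact GELU, as computed by relax.nn.gelu.
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    def gelu_tanh_approx(x):
        # The tanh approximation legalized above.
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

    for v in (-2.0, -0.5, 0.0, 0.5, 2.0):
        assert abs(gelu_exact(v) - gelu_tanh_approx(v)) < 1e-2
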
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 508e8bccba..a516c4e8aa 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -994,7 +994,7 @@ def CombineParallelMatmul():
     the fused ops are applied to the combined matmul output before slicing.
 
     Currently, only a limited set of fused ops is supported. It includes bias add,
-    relu, gelu, and silu activation.
+    relu, gelu, gelu_tanh, and silu activation.
 
     Returns
     -------
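
A hedged sketch of the pass in action: two parallel matmuls sharing the same
LHS, each followed by gelu_tanh, are now eligible for combination (all shapes
and names below are illustrative):

    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((4, 8), "float32"))
    w0 = relax.Var("w0", R.Tensor((8, 16), "float32"))
    w1 = relax.Var("w1", R.Tensor((8, 16), "float32"))
    with bb.function("main", [x, w0, w1]):
        with bb.dataflow():
            y0 = bb.emit(relax.op.nn.gelu_tanh(relax.op.matmul(x, w0)))
            y1 = bb.emit(relax.op.nn.gelu_tanh(relax.op.matmul(x, w1)))
            out = bb.emit_output(relax.op.add(y0, y1))
        bb.emit_func_output(out)

    mod = relax.transform.CombineParallelMatmul()(bb.get())
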
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index ec2205d1b7..384cab5f4d 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -31,6 +31,9 @@ RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/fal
 /* relax.nn.gelu */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/true);
 
+/* relax.nn.gelu_tanh */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu_tanh, "nn.gelu_tanh", /*require_float_dtype=*/true);
+
 /* relax.nn.silu */
 RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true);
 
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 199cc691b9..38e605bb0b 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -51,6 +51,9 @@ Expr relu(Expr data);
 /*! \brief Gaussian Error Linear Units function. */
 Expr gelu(Expr data);
 
+/*! \brief Gaussian Error Linear Units function approximated by tanh. */
+Expr gelu_tanh(Expr data);
+
 /*! \brief Sigmoid Linear Unit function. */
 Expr silu(Expr data);
 
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index a7f8711a1f..6efa4552ac 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -169,6 +169,8 @@ runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
           matmul_combined = relu(matmul_combined);
         } else if (*branch_info.activation == "relax.nn.gelu") {
           matmul_combined = gelu(matmul_combined);
+        } else if (*branch_info.activation == "relax.nn.gelu_tanh") {
+          matmul_combined = gelu_tanh(matmul_combined);
         } else if (*branch_info.activation == "relax.nn.silu") {
           matmul_combined = silu(matmul_combined);
         } else {
@@ -212,7 +214,8 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
   auto matmul_pat = IsOp("relax.matmul")(Wildcard(), Wildcard());
   auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat);
 
-  std::vector<std::string> activations{"relax.nn.relu", "relax.nn.gelu", "relax.nn.silu"};
+  std::vector<std::string> activations{"relax.nn.relu", "relax.nn.gelu", "relax.nn.gelu_tanh",
+                                       "relax.nn.silu"};
 
   std::vector<DFPattern> activation_pat, bias_activation_pat;
   for (const auto& act : activations) {
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 72ea193a02..76d8d366c4 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -40,6 +40,7 @@ def test_relax_dynamo():
             return torch.nn.functional.relu(self.lin(x))
 
     model = Input1()
+
     ### construct the database
     @tvm.script.ir_module
     class Input1_ir:
@@ -328,6 +329,58 @@ def test_full():
     )
 
 
+@tvm.testing.requires_gpu
+def test_gelu():
+    import torch
+    from torch.nn import Module
+
+    class GeLU(Module):
+        def forward(self, input):
+            return torch.nn.functional.gelu(input)
+
+    class GeLUTanh(Module):
+        def forward(self, input):
+            return torch.nn.functional.gelu(input, approximate="tanh")
+
+    @I.ir_module
+    class ExpectedGeLU:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 256), dtype="float32")
+        ) -> R.Tensor((128, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 256), dtype="float32") = R.nn.gelu(inp_0)
+                gv: R.Tensor((128, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class ExpectedGeLUTanh:
+        @R.function
+        def main(
+            inp_0: R.Tensor((128, 256), dtype="float32")
+        ) -> R.Tensor((128, 256), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((128, 256), dtype="float32") = R.nn.gelu_tanh(inp_0)
+                gv: R.Tensor((128, 256), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_dynamo_model(
+        GeLU(),
+        [([128, 256], "float32")],
+        {},
+        ExpectedGeLU,
+    )
+
+    verify_dynamo_model(
+        GeLUTanh(),
+        [([128, 256], "float32")],
+        {},
+        ExpectedGeLUTanh,
+    )
+
+
 @tvm.testing.requires_gpu
 def test_masked_fill():
     import torch
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index d062a91c9d..1ff0569629 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1164,6 +1164,159 @@ def test_gelu_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_gelu_tanh():
+    # fmt: off
+    @tvm.script.ir_module
+    class GeluTanh:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.nn.gelu_tanh(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            compute = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_power"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_power[v_ax0, v_ax1])
+                    T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_power[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_1[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_multiply_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_add[v_ax0, v_ax1])
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_multiply_2[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_add_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_add_1[v_ax0, v_ax1])
+                    T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_multiply_2"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
+
+
+    mod = LegalizeOps()(GeluTanh)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_gelu_tanh_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class GeluTanh:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):
+            m = T.int64()
+            n = T.int64()
+            gv: R.Tensor((m, n), "float32") = R.nn.gelu_tanh(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", 
"n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            gv = R.call_tir(Expected.gelu_tanh, (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            m, n = T.int64(), T.int64()
+            A = T.match_buffer(var_A, (m, n))
+            T_multiply = T.match_buffer(var_T_multiply, (m, n))
+            T_power = T.alloc_buffer((m, n))
+            T_multiply_1 = T.alloc_buffer((m, n))
+            T_add = T.alloc_buffer((m, n))
+            T_multiply_2 = T.alloc_buffer((m, n))
+            compute = T.alloc_buffer((m, n))
+            T_add_1 = T.alloc_buffer((m, n))
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_power"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_power[v_ax0, v_ax1])
+                    T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_power[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_1[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_multiply_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(T_add[v_ax0, v_ax1])
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+            for i0, i1 in T.grid(m, n):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_multiply_2[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_add_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(compute[v_ax0, v_ax1])
+                    T.writes(T_add_1[v_ax0, v_ax1])
+                    T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_multiply_2"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
+
+
+    mod = LegalizeOps()(GeluTanh)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_silu():
     # fmt: off
     @tvm.script.ir_module

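
For completeness, an end-to-end sketch exercising the legalized operator,
assuming a CPU (llvm) build of TVM unity; the module mirrors the test above:

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.relax.transform import LegalizeOps
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class GeluTanhMod:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.gelu_tanh(x)
            return gv

    mod = LegalizeOps()(GeluTanhMod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x = tvm.nd.array(np.random.randn(2, 3).astype("float32"))
    y = vm["main"](x)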