This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new d6015c5643 [Unity][BugFix] Fix a bug in relax gelu_tanh computation (#16188)
d6015c5643 is described below

commit d6015c564399ea98f94922497022f83f3f81d97e
Author: Rick Zhou <[email protected]>
AuthorDate: Thu Nov 30 17:43:24 2023 -0500

    [Unity][BugFix] Fix a bug in relax gelu_tanh computation (#16188)
    
    Fix a bug in relax gelu_tanh computation
---
 python/tvm/relax/frontend/nn/op.py                 | 11 +---
 python/tvm/relax/transform/legalize_ops/nn.py      | 14 ++--
 .../python/relax/test_transform_legalize_ops_nn.py | 75 +++++++++++++---------
 3 files changed, 54 insertions(+), 46 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 03080615e9..b95ceac4ed 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -827,17 +827,8 @@ def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Te
     ----
     The input tensor is required to have float dtype
     """
-    dtype = x._expr.struct_info.dtype
     if approximate == "tanh":
-        tanh_const = rx.const(1 + np.tanh(np.sqrt(2 / np.pi)), dtype=dtype)
-        gelu_out = (
-            rx.const(0.5, dtype)
-            * x._expr
-            * (
-                tanh_const
-                * (x._expr + (rx.const(0.044715, dtype) * _op.power(x._expr, rx.const(3, "int32"))))
-            )
-        )
+        gelu_out = _op.nn.gelu_tanh(x._expr)
     else:
         gelu_out = _op.nn.gelu(x._expr)
     return _wrap_nested(gelu_out, name)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index a82f54b84c..f2453a67b6 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -302,11 +302,15 @@ def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr:
 def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr:
     def te_gelu_tanh(x: te.Tensor):
         dtype = x.dtype
-        return tir.const(0.5, dtype) * (
-            tir.const(1.0, dtype)
-            + topi.tanh(
-                tir.const(math.sqrt(2.0 / math.pi), dtype)
-                * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+        return (
+            tir.const(0.5, dtype)
+            * x
+            * (
+                tir.const(1.0, dtype)
+                + topi.tanh(
+                    tir.const(math.sqrt(2.0 / math.pi), dtype)
+                    * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+                )
             )
         )
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 47f00dca87..74da77f7d8 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1258,12 +1258,19 @@ def test_gelu_tanh():
         @T.prim_func(private=True)
         def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
-            T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
             T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
-            T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
             T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
             compute = T.alloc_buffer((T.int64(2), T.int64(3)))
             T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, 
v_ax1]
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_power"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1271,29 +1278,29 @@ def test_gelu_tanh():
                     T.writes(T_power[v_ax0, v_ax1])
                     T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], 
T.float32(3))
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_multiply"):
+                with T.block("T_multiply_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_power[v_ax0, v_ax1])
-                    T.writes(T_multiply_1[v_ax0, v_ax1])
-                    T_multiply_1[v_ax0, v_ax1] = 
T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = 
T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_add"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+                    T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
                     T.writes(T_add[v_ax0, v_ax1])
-                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 
T_multiply_1[v_ax0, v_ax1]
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 
T_multiply_2[v_ax0, v_ax1]
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_multiply_1"):
+                with T.block("T_multiply_2"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_add[v_ax0, v_ax1])
-                    T.writes(T_multiply_2[v_ax0, v_ax1])
-                    T_multiply_2[v_ax0, v_ax1] = 
T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+                    T.writes(T_multiply_3[v_ax0, v_ax1])
+                    T_multiply_3[v_ax0, v_ax1] = 
T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("compute"):
                     v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_2[v_i0, v_i1])
+                    T.reads(T_multiply_3[v_i0, v_i1])
                     T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_add_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1301,12 +1308,11 @@ def test_gelu_tanh():
                     T.writes(T_add_1[v_ax0, v_ax1])
                     T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, 
v_ax1]
             for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
-                with T.block("T_multiply_2"):
+                with T.block("T_multiply_3"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(T_add_1[v_ax0, v_ax1])
+                    T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                     T.writes(T_multiply[v_ax0, v_ax1])
-                    T_multiply[v_ax0, v_ax1] = T.float32(0.5) * T_add_1[v_ax0, 
v_ax1]
-
+                    T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * 
T_add_1[v_ax0, v_ax1]
 
     mod = LegalizeOps()(GeluTanh)
     tvm.ir.assert_structural_equal(mod, Expected)
@@ -1338,12 +1344,19 @@ def test_gelu_tanh_symbolic():
             m, n = T.int64(), T.int64()
             A = T.match_buffer(var_A, (m, n))
             T_multiply = T.match_buffer(var_T_multiply, (m, n))
-            T_power = T.alloc_buffer((m, n))
             T_multiply_1 = T.alloc_buffer((m, n))
-            T_add = T.alloc_buffer((m, n))
+            T_power = T.alloc_buffer((m, n))
             T_multiply_2 = T.alloc_buffer((m, n))
+            T_add = T.alloc_buffer((m, n))
+            T_multiply_3 = T.alloc_buffer((m, n))
             compute = T.alloc_buffer((m, n))
             T_add_1 = T.alloc_buffer((m, n))
+            for ax0, ax1 in T.grid(m, n):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A[v_ax0, v_ax1])
+                    T.writes(T_multiply_1[v_ax0, v_ax1])
+                    T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, 
v_ax1]
             for ax0, ax1 in T.grid(m, n):
                 with T.block("T_power"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1351,29 +1364,29 @@ def test_gelu_tanh_symbolic():
                     T.writes(T_power[v_ax0, v_ax1])
                     T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], 
T.float32(3))
             for ax0, ax1 in T.grid(m, n):
-                with T.block("T_multiply"):
+                with T.block("T_multiply_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_power[v_ax0, v_ax1])
-                    T.writes(T_multiply_1[v_ax0, v_ax1])
-                    T_multiply_1[v_ax0, v_ax1] = 
T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+                    T.writes(T_multiply_2[v_ax0, v_ax1])
+                    T_multiply_2[v_ax0, v_ax1] = 
T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
             for ax0, ax1 in T.grid(m, n):
                 with T.block("T_add"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+                    T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
                     T.writes(T_add[v_ax0, v_ax1])
-                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 
T_multiply_1[v_ax0, v_ax1]
+                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + 
T_multiply_2[v_ax0, v_ax1]
             for ax0, ax1 in T.grid(m, n):
-                with T.block("T_multiply_1"):
+                with T.block("T_multiply_2"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(T_add[v_ax0, v_ax1])
-                    T.writes(T_multiply_2[v_ax0, v_ax1])
-                    T_multiply_2[v_ax0, v_ax1] = 
T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+                    T.writes(T_multiply_3[v_ax0, v_ax1])
+                    T_multiply_3[v_ax0, v_ax1] = 
T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
             for i0, i1 in T.grid(m, n):
                 with T.block("compute"):
                     v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                    T.reads(T_multiply_2[v_i0, v_i1])
+                    T.reads(T_multiply_3[v_i0, v_i1])
                     T.writes(compute[v_i0, v_i1])
-                    compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
             for ax0, ax1 in T.grid(m, n):
                 with T.block("T_add_1"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1381,11 +1394,11 @@ def test_gelu_tanh_symbolic():
                     T.writes(T_add_1[v_ax0, v_ax1])
                     T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, 
v_ax1]
             for ax0, ax1 in T.grid(m, n):
-                with T.block("T_multiply_2"):
+                with T.block("T_multiply_3"):
                     v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
-                    T.reads(T_add_1[v_ax0, v_ax1])
+                    T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
                     T.writes(T_multiply[v_ax0, v_ax1])
-                    T_multiply[v_ax0, v_ax1] = T.float32(0.5) * T_add_1[v_ax0, 
v_ax1]
+                    T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * 
T_add_1[v_ax0, v_ax1]
 
 
     mod = LegalizeOps()(GeluTanh)

Reply via email to