This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d6015c5643 [Unity][BugFix] Fix a bug in relax gelu_tanh computation (#16188)
d6015c5643 is described below
commit d6015c564399ea98f94922497022f83f3f81d97e
Author: Rick Zhou <[email protected]>
AuthorDate: Thu Nov 30 17:43:24 2023 -0500
[Unity][BugFix] Fix a bug in relax gelu_tanh computation (#16188)
Fix a bug in relax gelu_tanh computation
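For context, both changed files are expected to compute the tanh approximation of GELU (Hendrycks & Gimpel, 2016):

gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

A minimal NumPy sketch of that formula (not part of this commit; the helper name gelu_tanh_ref is illustrative) gives an independent reference for the fixed lowering:

    import numpy as np

    def gelu_tanh_ref(x: np.ndarray) -> np.ndarray:
        # tanh approximation of GELU:
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        inner = np.sqrt(2.0 / np.pi) * (x + 0.044715 * np.power(x, 3))
        return 0.5 * x * (1.0 + np.tanh(inner))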
---
python/tvm/relax/frontend/nn/op.py | 11 +---
python/tvm/relax/transform/legalize_ops/nn.py | 14 ++--
.../python/relax/test_transform_legalize_ops_nn.py | 75 +++++++++++++---------
3 files changed, 54 insertions(+), 46 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 03080615e9..b95ceac4ed 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -827,17 +827,8 @@ def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Te
----
The input tensor is required to have float dtype
"""
- dtype = x._expr.struct_info.dtype
if approximate == "tanh":
- tanh_const = rx.const(1 + np.tanh(np.sqrt(2 / np.pi)), dtype=dtype)
- gelu_out = (
- rx.const(0.5, dtype)
- * x._expr
- * (
- tanh_const
- * (x._expr + (rx.const(0.044715, dtype) * _op.power(x._expr, rx.const(3, "int32"))))
- )
- )
+ gelu_out = _op.nn.gelu_tanh(x._expr)
else:
gelu_out = _op.nn.gelu(x._expr)
return _wrap_nested(gelu_out, name)
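The removed frontend code above evaluated 1 + np.tanh(np.sqrt(2 / np.pi)) once as a scalar constant, so tanh was applied to sqrt(2/pi) itself rather than to the scaled input, and the leading "1 +" was folded into that constant. Routing through _op.nn.gelu_tanh leaves a single definition of the formula to maintain. A quick NumPy comparison, reusing the illustrative gelu_tanh_ref sketched above:

    import numpy as np

    x = np.linspace(-3.0, 3.0, 7, dtype="float32")
    tanh_const = 1 + np.tanh(np.sqrt(2 / np.pi))           # what the removed code precomputed
    old = 0.5 * x * (tanh_const * (x + 0.044715 * x**3))   # removed (buggy) frontend formula
    new = gelu_tanh_ref(x)                                 # correct tanh approximation
    print(np.abs(old - new).max())                         # clearly nonzero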
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index a82f54b84c..f2453a67b6 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -302,11 +302,15 @@ def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr:
def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr:
def te_gelu_tanh(x: te.Tensor):
dtype = x.dtype
- return tir.const(0.5, dtype) * (
- tir.const(1.0, dtype)
- + topi.tanh(
- tir.const(math.sqrt(2.0 / math.pi), dtype)
- * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+ return (
+ tir.const(0.5, dtype)
+ * x
+ * (
+ tir.const(1.0, dtype)
+ + topi.tanh(
+ tir.const(math.sqrt(2.0 / math.pi), dtype)
+ * (x + tir.const(0.044715, dtype) * topi.power(x, 3))
+ )
)
)
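The legalization fixed here had a related but distinct bug: the old TE expression computed 0.5 * (1 + tanh(...)) without ever multiplying by x, which is exactly the factor the new version threads through the product. In plain NumPy terms (names here are illustrative, not from the commit):

    import math
    import numpy as np

    x = np.linspace(-3.0, 3.0, 7, dtype="float32")
    inner = np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
    old_lowering = 0.5 * (1.0 + inner)       # old: missing the leading x factor
    new_lowering = 0.5 * x * (1.0 + inner)   # new: matches gelu_tanh(x)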
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 47f00dca87..74da77f7d8 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1258,12 +1258,19 @@ def test_gelu_tanh():
@T.prim_func(private=True)
def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"),
T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
- T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
- T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_power = T.alloc_buffer((T.int64(2), T.int64(3)))
T_multiply_2 = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_add = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_multiply_3 = T.alloc_buffer((T.int64(2), T.int64(3)))
compute = T.alloc_buffer((T.int64(2), T.int64(3)))
T_add_1 = T.alloc_buffer((T.int64(2), T.int64(3)))
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1])
+ T.writes(T_multiply_1[v_ax0, v_ax1])
+ T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_power"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1271,29 +1278,29 @@ def test_gelu_tanh():
T.writes(T_power[v_ax0, v_ax1])
T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("T_multiply"):
+ with T.block("T_multiply_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_power[v_ax0, v_ax1])
- T.writes(T_multiply_1[v_ax0, v_ax1])
- T_multiply_1[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+ T.writes(T_multiply_2[v_ax0, v_ax1])
+ T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+ T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
- T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_1[v_ax0, v_ax1]
+ T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("T_multiply_1"):
+ with T.block("T_multiply_2"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_add[v_ax0, v_ax1])
- T.writes(T_multiply_2[v_ax0, v_ax1])
- T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+ T.writes(T_multiply_3[v_ax0, v_ax1])
+ T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(T_multiply_2[v_i0, v_i1])
+ T.reads(T_multiply_3[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_add_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1301,12 +1308,11 @@ def test_gelu_tanh():
T.writes(T_add_1[v_ax0, v_ax1])
T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
- with T.block("T_multiply_2"):
+ with T.block("T_multiply_3"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(T_add_1[v_ax0, v_ax1])
+ T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
- T_multiply[v_ax0, v_ax1] = T.float32(0.5) * T_add_1[v_ax0, v_ax1]
-
+ T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
mod = LegalizeOps()(GeluTanh)
tvm.ir.assert_structural_equal(mod, Expected)
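Beyond structural equality, the legalized module could also be checked numerically against the reference formula. A sketch in the usual relax test style, assuming the unity-branch relax.build/VirtualMachine API, an entry function named "main", and the illustrative gelu_tanh_ref helper from above:

    import numpy as np
    import tvm
    from tvm import relax

    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    a = np.random.rand(2, 3).astype("float32")
    out = vm["main"](tvm.nd.array(a)).numpy()
    np.testing.assert_allclose(out, gelu_tanh_ref(a), rtol=1e-5, atol=1e-5)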
@@ -1338,12 +1344,19 @@ def test_gelu_tanh_symbolic():
m, n = T.int64(), T.int64()
A = T.match_buffer(var_A, (m, n))
T_multiply = T.match_buffer(var_T_multiply, (m, n))
- T_power = T.alloc_buffer((m, n))
T_multiply_1 = T.alloc_buffer((m, n))
- T_add = T.alloc_buffer((m, n))
+ T_power = T.alloc_buffer((m, n))
T_multiply_2 = T.alloc_buffer((m, n))
+ T_add = T.alloc_buffer((m, n))
+ T_multiply_3 = T.alloc_buffer((m, n))
compute = T.alloc_buffer((m, n))
T_add_1 = T.alloc_buffer((m, n))
+ for ax0, ax1 in T.grid(m, n):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A[v_ax0, v_ax1])
+ T.writes(T_multiply_1[v_ax0, v_ax1])
+ T_multiply_1[v_ax0, v_ax1] = T.float32(0.5) * A[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_power"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1351,29 +1364,29 @@ def test_gelu_tanh_symbolic():
T.writes(T_power[v_ax0, v_ax1])
T_power[v_ax0, v_ax1] = T.pow(A[v_ax0, v_ax1], T.float32(3))
for ax0, ax1 in T.grid(m, n):
- with T.block("T_multiply"):
+ with T.block("T_multiply_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_power[v_ax0, v_ax1])
- T.writes(T_multiply_1[v_ax0, v_ax1])
- T_multiply_1[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
+ T.writes(T_multiply_2[v_ax0, v_ax1])
+ T_multiply_2[v_ax0, v_ax1] = T.float32(0.044714999999999998) * T_power[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
with T.block("T_add"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(A[v_ax0, v_ax1], T_multiply_1[v_ax0, v_ax1])
+ T.reads(A[v_ax0, v_ax1], T_multiply_2[v_ax0, v_ax1])
T.writes(T_add[v_ax0, v_ax1])
- T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_1[v_ax0, v_ax1]
+ T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T_multiply_2[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
- with T.block("T_multiply_1"):
+ with T.block("T_multiply_2"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
T.reads(T_add[v_ax0, v_ax1])
- T.writes(T_multiply_2[v_ax0, v_ax1])
- T_multiply_2[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
+ T.writes(T_multiply_3[v_ax0, v_ax1])
+ T_multiply_3[v_ax0, v_ax1] = T.float32(0.79788456080286541) * T_add[v_ax0, v_ax1]
for i0, i1 in T.grid(m, n):
with T.block("compute"):
v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
- T.reads(T_multiply_2[v_i0, v_i1])
+ T.reads(T_multiply_3[v_i0, v_i1])
T.writes(compute[v_i0, v_i1])
- compute[v_i0, v_i1] = T.tanh(T_multiply_2[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.tanh(T_multiply_3[v_i0, v_i1])
for ax0, ax1 in T.grid(m, n):
with T.block("T_add_1"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
@@ -1381,11 +1394,11 @@ def test_gelu_tanh_symbolic():
T.writes(T_add_1[v_ax0, v_ax1])
T_add_1[v_ax0, v_ax1] = T.float32(1) + compute[v_ax0, v_ax1]
for ax0, ax1 in T.grid(m, n):
- with T.block("T_multiply_2"):
+ with T.block("T_multiply_3"):
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
- T.reads(T_add_1[v_ax0, v_ax1])
+ T.reads(T_multiply_1[v_ax0, v_ax1], T_add_1[v_ax0, v_ax1])
T.writes(T_multiply[v_ax0, v_ax1])
- T_multiply[v_ax0, v_ax1] = T.float32(0.5) * T_add_1[v_ax0, v_ax1]
+ T_multiply[v_ax0, v_ax1] = T_multiply_1[v_ax0, v_ax1] * T_add_1[v_ax0, v_ax1]
mod = LegalizeOps()(GeluTanh)