This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f23d6b2434 [Relay][Bugfix] fix the wrong implementation of Softplus in OneFlow (#15717)
f23d6b2434 is described below
commit f23d6b2434c66871cbb3341ce66075453f452184
Author: jikechao <[email protected]>
AuthorDate: Sun Sep 10 13:13:54 2023 +0800
[Relay][Bugfix] fix the wrong implementation of Softplus in OneFlow (#15717)
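* Update test_forward.py: re-enable the `model2` activation check that was commented out as `# NO PASS`
* fix a bug in softplus: honor the `beta` and `threshold` attributes instead of always computing log(exp(x) + 1)
* Update oneflow.py (Softplus converter)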
---
python/tvm/relay/frontend/oneflow.py | 7 +++++--
tests/python/frontend/oneflow/test_forward.py | 2 +-
2 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py
index 4f278d8249..7a713e5e15 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -1119,8 +1119,11 @@ class Softplus(OneFlowOpConverter):
def _impl_v1(cls, inputs, attrs, params):
data = inputs[0]
data_dtype = infer_type(data).checked_type.dtype
- data = _op.exp(data) + _expr.const(1, dtype=data_dtype)
- return _op.log(data)
+ beta = _expr.const(float(attrs.get("beta", 1.0)))
+ threshold = float(attrs.get("threshold", 20.0))
+ threshold_ = _op.full_like(data, fill_value=_expr.const(threshold))
+ softplus_value = _op.log(_op.exp(data * beta) + _expr.const(1.0,
dtype=data_dtype)) / beta
+ return _op.where(_op.greater(data * beta, threshold_), data,
softplus_value)
class Softsign(OneFlowOpConverter):
diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py
index cc9333cd03..17583b3c25 100644
--- a/tests/python/frontend/oneflow/test_forward.py
+++ b/tests/python/frontend/oneflow/test_forward.py
@@ -721,7 +721,7 @@ def test_activation():
for device in ["llvm"]:
verify_activation(model1, device=device)
- # verify_activation(model2, device=device) # NO PASS
+ verify_activation(model2, device=device)
verify_activation(model3, device=device)
verify_activation(model4, device=device)
verify_activation(model5, device=device)