This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f23d6b2434 [Relay][Bugfix] fix the wrong implementation of Softplus in OneFlow (#15717)
f23d6b2434 is described below

commit f23d6b2434c66871cbb3341ce66075453f452184
Author: jikechao <[email protected]>
AuthorDate: Sun Sep 10 13:13:54 2023 +0800

    [Relay][Bugfix] fix the wrong implementation of Softplus in OneFlow (#15717)
    
    * Update test_forward.py
    
    * fix a bug in softplus
    
    * Update oneflow.py
---
 python/tvm/relay/frontend/oneflow.py          | 7 +++++--
 tests/python/frontend/oneflow/test_forward.py | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py
index 4f278d8249..7a713e5e15 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -1119,8 +1119,11 @@ class Softplus(OneFlowOpConverter):
     def _impl_v1(cls, inputs, attrs, params):
         data = inputs[0]
         data_dtype = infer_type(data).checked_type.dtype
-        data = _op.exp(data) + _expr.const(1, dtype=data_dtype)
-        return _op.log(data)
+        beta = _expr.const(float(attrs.get("beta", 1.0)))
+        threshold = float(attrs.get("threshold", 20.0))
+        threshold_ = _op.full_like(data, fill_value=_expr.const(threshold))
+        softplus_value = _op.log(_op.exp(data * beta) + _expr.const(1.0, dtype=data_dtype)) / beta
+        return _op.where(_op.greater(data * beta, threshold_), data, softplus_value)
 
 
 class Softsign(OneFlowOpConverter):
diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py
index cc9333cd03..17583b3c25 100644
--- a/tests/python/frontend/oneflow/test_forward.py
+++ b/tests/python/frontend/oneflow/test_forward.py
@@ -721,7 +721,7 @@ def test_activation():
 
     for device in ["llvm"]:
         verify_activation(model1, device=device)
-        # verify_activation(model2, device=device) # NO PASS
+        verify_activation(model2, device=device)
         verify_activation(model3, device=device)
         verify_activation(model4, device=device)
         verify_activation(model5, device=device)

Reply via email to