This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 602133e6b9 [bugfix][Relay] Fix softplus in paddlepaddle frontend (#14845)
602133e6b9 is described below

commit 602133e6b9a3c7925abd5bbdb315f92fa7170c93
Author: Qingchao Shen <[email protected]>
AuthorDate: Mon May 15 12:34:16 2023 +0800

    [bugfix][Relay] Fix softplus in paddlepaddle frontend (#14845)
    
    * fix softplus in paddlepaddle.py
    
    * add test case
    
    * Update test_forward.py
---
 python/tvm/relay/frontend/paddlepaddle.py          | 5 ++++-
 tests/python/frontend/paddlepaddle/test_forward.py | 4 +++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index fdbc96676f..4268a4876a 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -2163,7 +2163,10 @@ def convert_softplus(g, op, block):
     dtype = infer_type(x).checked_type.dtype
     beta = op.attr("beta")
     beta = _expr.const(beta, dtype=dtype)
-    out = _op.log(_op.exp(x * beta) + _expr.const(1.0, dtype=dtype)) / beta
+    threshold = op.attr("threshold")
+    threshold = _expr.const(threshold, dtype=dtype)
+    out_softplus = _op.log(_op.exp(x * beta) + _expr.const(1.0, dtype=dtype)) / beta
+    out = _op.where(_op.greater(x * beta, threshold), x, out_softplus)
     g.add_node(op.output("Out")[0], out)
 
 
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index 289fc0faa3..1555ba1aaa 100755
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -1722,7 +1722,9 @@ def test_forward_sin():
 
 @run_math_api
 def test_forward_softplus():
-    pass
+    x = paddle.to_tensor([-0.4, 1], dtype="float32")
+    m = paddle.nn.Softplus(5, 1)
+    verify_model(m, [x])
 
 
 @run_math_api

Reply via email to