This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2cafa87b10 [Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)
2cafa87b10 is described below
commit 2cafa87b10c6124f1a08af7ead712f29b9039762
Author: Qingchao Shen <[email protected]>
AuthorDate: Thu May 11 18:59:04 2023 +0800
[Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)
* fix threshold
* add test case
* Update pytorch.py
* Update pytorch.py
---
python/tvm/relay/frontend/pytorch.py | 6 +++++-
tests/python/frontend/pytorch/test_forward.py | 2 ++
2 files changed, 7 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index afd46b2001..5e2e6a5f5e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1333,7 +1333,11 @@ class PyTorchOpConverter:
def threshold(self, inputs, input_types):
data = inputs[0]
- return _op.nn.relu(data)
+ threshold_f = float(inputs[1])
+ threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold_f))
+ value_f = float(inputs[2])
+ value = _op.full_like(inputs[0], fill_value=_expr.const(value_f))
+ return _op.where(_op.greater(data, threshold_), data, value)
def contiguous(self, inputs, input_types):
return inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9e5e9e22bc..fcaf7b7847 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1348,6 +1348,8 @@ def test_forward_threshold():
input_shape = [1, 3]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.Threshold(0, 0).float().eval(), input_data=input_data)
+ input_data = torch.tensor([[-1.0, 2.0]], dtype=torch.float32)
+ verify_model(torch.nn.Threshold(1, 1).float().eval(), input_data=input_data)
@tvm.testing.uses_gpu