This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e12ddca [FRONTEND][PYTORCH] Support for nn.SiLU added (#8753)
e12ddca is described below
commit e12ddcafd74cc10cef343fc39a0c6a892a431650
Author: Alperen Bag <[email protected]>
AuthorDate: Sun Aug 15 07:01:08 2021 +0300
[FRONTEND][PYTORCH] Support for nn.SiLU added (#8753)
---
python/tvm/relay/frontend/pytorch.py | 5 +++++
tests/python/frontend/pytorch/test_forward.py | 8 ++++++++
2 files changed, 13 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py
b/python/tvm/relay/frontend/pytorch.py
index 9406c3b..7c10889 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -804,6 +804,10 @@ class PyTorchOpConverter:
alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data))
+ _op.nn.relu(data)
)
+ def silu(self, inputs, input_types):
+ data = inputs[0]
+ return data * _op.tensor.sigmoid(data)
+
def log_sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.log(_op.tensor.sigmoid(data))
@@ -2623,6 +2627,7 @@ class PyTorchOpConverter:
"aten::celu": self.celu,
"aten::gelu": self.gelu,
"aten::selu": self.selu,
+ "aten::silu": self.silu,
"aten::log_sigmoid": self.log_sigmoid,
"aten::adaptive_avg_pool2d": self.adaptive_avg_pool_2d,
"aten::adaptive_max_pool2d": self.adaptive_max_pool_2d,
diff --git a/tests/python/frontend/pytorch/test_forward.py
b/tests/python/frontend/pytorch/test_forward.py
index c924e73..e2cb51a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -701,6 +701,14 @@ def test_forward_selu():
@tvm.testing.uses_gpu
+def test_forward_silu():
+ torch.set_grad_enabled(False)
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(torch.nn.SiLU().eval(), input_data=input_data)
+
+
[email protected]_gpu
def test_forward_softplus():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]