This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d7ae4c74fc [Relax] [PyTorch] Add support for torch.nn.Hardsigmoid (#17085)
d7ae4c74fc is described below
commit d7ae4c74fc0363f36fc5c0fdc2d40c2e64d5ae9c
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Jun 14 03:24:10 2024 +0900
[Relax] [PyTorch] Add support for torch.nn.Hardsigmoid (#17085)
add hardsigmoid support to fx_frontend
---
python/tvm/relax/frontend/torch/fx_translator.py | 10 +++++++
tests/python/relax/test_frontend_from_fx.py | 35 ++++++++++++++++++++++++
2 files changed, 45 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a5efcce278..5ed0f18deb 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -243,6 +243,14 @@ class TorchFXImporter:
else:
raise KeyError("Unregonized approximate algorithm for gelu: {}.".format(approximate))
+ def _hardsigmoid(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ x0 = relax.op.add(x, relax.const(3, dtype))
+ x1 = relax.op.clip(x0, 0, 6)
+ return self.block_builder.emit(relax.op.divide(x1, relax.const(6, dtype)))
+
def _hardswish(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
@@ -1367,6 +1375,7 @@ class TorchFXImporter:
nn.Sigmoid: self._sigmoid,
nn.Tanh: lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+ nn.Hardsigmoid: self._hardsigmoid,
nn.Hardswish: self._hardswish,
nn.Flatten: self._flatten,
nn.BatchNorm2d: self._batch_norm_2d,
@@ -1447,6 +1456,7 @@ class TorchFXImporter:
"leaky_relu": self._leakyrelu,
"gelu": self._gelu,
"silu": lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+ "hardsigmoid": self._hardsigmoid,
"hardswish": self._hardswish,
"interpolate": self._interpolate,
"size": self._size,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 49131b5ff8..dd2719f8ce 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1416,6 +1416,41 @@ def test_silu():
verify_model(SiLU2(), input_info, {}, expected1)
+def test_hardsigmoid():
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ class Hardsigmoid(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.hs = torch.nn.Hardsigmoid()
+
+ def forward(self, input):
+ return self.hs(input)
+
+ class Hardsigmoid2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardsigmoid(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32"))
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6)
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv1, R.const(6, "float32")
+ )
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2
+ R.output(gv)
+ return gv
+
+ verify_model(Hardsigmoid(), input_info, {}, expected1)
+ verify_model(Hardsigmoid2(), input_info, {}, expected1)
+
+
def test_hardswish():
input_info = [([1, 3, 10, 10], "float32")]