This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cc7eb2faae [Relax] [PyTorch] Add support for torch.nn.Hardswish (#17084)
cc7eb2faae is described below

commit cc7eb2faae3444ee02b142a5aea237dd1db6d29a
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Thu Jun 13 02:09:50 2024 +0900

    [Relax] [PyTorch] Add support for torch.nn.Hardswish (#17084)
    
    * add hardswish support to fx_frontend
    
    * run ./tests/lint/git-black.sh -i --rev upstream/main
    
    * fix ci lint error
---
 python/tvm/relax/frontend/torch/fx_translator.py | 11 ++++++++
 tests/python/relax/test_frontend_from_fx.py      | 36 ++++++++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index e26e9bc7dc..a5efcce278 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -243,6 +243,15 @@ class TorchFXImporter:
         else:
             raise KeyError("Unregonized approximate algorithm for gelu: 
{}.".format(approximate))
 
+    def _hardswish(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        x0 = relax.op.add(x, relax.const(3, dtype))
+        x1 = relax.op.clip(x0, 0, 6)
+        x2 = relax.op.divide(x1, relax.const(6, dtype))
+        return self.block_builder.emit(relax.op.multiply(x, x2))
+
     ########## Compare ##########
 
     def _lt(self, node: fx.node.Node) -> relax.Expr:
@@ -1358,6 +1367,7 @@ class TorchFXImporter:
             nn.Sigmoid: self._sigmoid,
             nn.Tanh: lambda node: 
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             nn.SiLU: lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            nn.Hardswish: self._hardswish,
             nn.Flatten: self._flatten,
             nn.BatchNorm2d: self._batch_norm_2d,
             nn.LayerNorm: self._layer_norm,
@@ -1437,6 +1447,7 @@ class TorchFXImporter:
             "leaky_relu": self._leakyrelu,
             "gelu": self._gelu,
             "silu": lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            "hardswish": self._hardswish,
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index dfa5cad4a5..49131b5ff8 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1416,6 +1416,42 @@ def test_silu():
     verify_model(SiLU2(), input_info, {}, expected1)
 
 
+def test_hardswish():
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Hardswish(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.hs = torch.nn.Hardswish()
+
+        def forward(self, input):
+            return self.hs(input)
+
+    class Hardswish2(torch.nn.Module):
+        def forward(self, input):
+            return torch.nn.functional.hardswish(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, 
R.const(3, "float32"))
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 
6)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv1, R.const(6, "float32")
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(inp_0, lv2)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    verify_model(Hardswish(), input_info, {}, expected1)
+    verify_model(Hardswish2(), input_info, {}, expected1)
+
+
 def test_groupnorm():
     import torch
     from torch.nn import Module

Reply via email to