This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bc7809535d [Unity][Frontend] Import `tanh` and fix `layer_norm` (#14247)
bc7809535d is described below

commit bc7809535d03f690f6d40848e4eb9a53bd14f64c
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Mar 10 22:26:00 2023 +0800

    [Unity][Frontend] Import `tanh` and fix `layer_norm` (#14247)
    
    This PR provides some quick fixes for fx_translator to import tanh and fix the error when importing torch.nn.functional.layer_norm.
---
 python/tvm/relax/frontend/torch/fx_translator.py | 32 +++++++++
 tests/python/relax/test_frontend_from_fx.py      | 82 ++++++++++++++++++++++++
 2 files changed, 114 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a157b80337..41e8e775a4 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -609,8 +609,38 @@ class TorchFXImporter:
 
     def _layer_norm(self, node: fx.node.Node) -> relax.Var:
         import torch  # type: ignore
+        import numpy as np  # type: ignore
 
         x = self.env[node.args[0]]
+
+        # functional.layer_norm
+        if node.target not in self.named_modules:
+            # static or symbolic
+            normalized_shape = (
+                node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]]
+            )
+            dim_num = len(normalized_shape)
+            axes = list(range(-dim_num, 0))
+
+            gamma = self.env[node.kwargs["weight"]]
+            beta = node.kwargs["bias"]
+            if beta is None:
+                shape_tuple = [int(s) for s in normalized_shape.values]
+                beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
+            else:
+                beta = self.env[beta]
+            eps = node.kwargs["eps"]
+
+            return self.block_builder.emit(
+                relax.op.nn.layer_norm(
+                    x,
+                    gamma,
+                    beta,
+                    axes=axes,
+                    epsilon=eps,
+                )
+            )
+
         module = self.named_modules[node.target]
 
         if module.elementwise_affine:
@@ -886,6 +916,7 @@ class TorchFXImporter:
             "clamp": self._clamp,
             "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
             "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+            "tanh": lambda node: self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
             "interpolate": self._interpolate,
             "size": self._size,
             "getattr": self._getattr,
@@ -893,6 +924,7 @@ class TorchFXImporter:
             "contiguous": lambda node: self.env[node.args[0]],
             "to": lambda node: self.env[node.args[0]],
             "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
+            "layer_norm": self._layer_norm,
         }
 
     def from_fx(
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 3447c99e5c..6467b6cf14 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -761,6 +761,58 @@ def test_layernorm():
     verify_model(LayerNorm(), input_info, binding, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_functional_layernorm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class LayerNorm(Module):
+        def __init__(self, shape):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.ones(shape))
+            self.bias = torch.nn.Parameter(torch.zeros(shape))
+
+        def forward(self, input):
+            return torch.nn.functional.layer_norm(
+                input, self.weight.shape, self.weight, self.bias, 1e-5
+            )
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((10, 10), dtype="float32"),
+            w2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    axes=[-2, -1],
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    model = LayerNorm((10, 10))
+    binding = {
+        "w1": model.weight.numpy(),
+        "w2": model.bias.numpy(),
+    }
+    verify_model(model, input_info, binding, expected1)
+
+
 @tvm.testing.requires_gpu
 def test_silu():
     import torch
@@ -1490,6 +1542,36 @@ def test_gelu():
     verify_model(Gelu(), input_info, {}, expected1)
 
 
+@tvm.testing.requires_gpu
+def test_tanh():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Tanh(Module):
+        def forward(self, input):
+            return torch.tanh(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tanh(), input_info, {}, expected1)
+
+
 @tvm.testing.requires_gpu
 def test_clamp():
     import torch

Reply via email to