This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b1380059f0 [Unity] [Bugfix] Fix TypeError in TVM PyTorch frontend for LayerNorm operator (#15902)
b1380059f0 is described below

commit b1380059f0eeab78f36425193f33506abb9b576b
Author: Thrsu <[email protected]>
AuthorDate: Tue Oct 10 18:42:49 2023 +0800

    [Unity] [Bugfix] Fix TypeError in TVM PyTorch frontend for LayerNorm operator (#15902)
---
 python/tvm/relax/frontend/torch/fx_translator.py | 12 +++++--
 tests/python/relax/test_frontend_from_fx.py      | 40 ++++++++++++++++++++++++
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a52b73f3ab..3d150e3eed 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -975,9 +975,15 @@ class TorchFXImporter:
         # functional.layer_norm
         if node.target not in self.named_modules:
             # static or symbolic
-            normalized_shape = (
-                node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]]
-            )
+            arg = node.args[1]
+            if isinstance(arg, tuple):
+                value = arg
+            else:
+                try:
+                    value = self.env[arg]
+                except TypeError:
+                    value = tuple(arg)
+            normalized_shape = value
             dim_num = len(normalized_shape)
             axes = list(range(-dim_num, 0))
 
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index c815051076..a1acff4974 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1152,6 +1152,46 @@ def test_functional_layernorm():
     binding = {}
     verify_model(model, input_info, binding, expected2)
 
+    class LayerNorm3(Module):
+        def __init__(self, shape):
+            super().__init__()
+            self.shape = shape
+            self.weight = torch.nn.Parameter(torch.ones(shape))
+            self.bias = torch.nn.Parameter(torch.zeros(shape))
+
+        def forward(self, input):
+            return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor([10, 10], dtype="float32"),
+            w2: R.Tensor([10, 10], dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    axes=[-2, -1],
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    model = LayerNorm3([10, 10])
+    binding = {
+        "w1": model.weight.detach().numpy(),
+        "w2": model.bias.detach().numpy(),
+    }
+    verify_model(model, input_info, binding, expected3)
+
 
 def test_cross_entropy():
     input_info = [([3, 2], "float32"), ([3], "int32")]

Reply via email to