This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b1380059f0 [Unity] [Bugfix] Fix TypeError in TVM PyTorch frontend for LayerNorm operator (#15902)
b1380059f0 is described below
commit b1380059f0eeab78f36425193f33506abb9b576b
Author: Thrsu <[email protected]>
AuthorDate: Tue Oct 10 18:42:49 2023 +0800
[Unity] [Bugfix] Fix TypeError in TVM PyTorch frontend for LayerNorm operator (#15902)
---
python/tvm/relax/frontend/torch/fx_translator.py | 12 +++++--
tests/python/relax/test_frontend_from_fx.py | 40 ++++++++++++++++++++++++
2 files changed, 49 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index a52b73f3ab..3d150e3eed 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -975,9 +975,15 @@ class TorchFXImporter:
# functional.layer_norm
if node.target not in self.named_modules:
# static or symbolic
- normalized_shape = (
- node.args[1] if type(node.args[1]) == tuple else self.env[node.args[1]]
- )
+ arg = node.args[1]
+ if isinstance(arg, tuple):
+ value = arg
+ else:
+ try:
+ value = self.env[arg]
+ except TypeError:
+ value = tuple(arg)
+ normalized_shape = value
dim_num = len(normalized_shape)
axes = list(range(-dim_num, 0))
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index c815051076..a1acff4974 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1152,6 +1152,46 @@ def test_functional_layernorm():
binding = {}
verify_model(model, input_info, binding, expected2)
+ class LayerNorm3(Module):
+ def __init__(self, shape):
+ super().__init__()
+ self.shape = shape
+ self.weight = torch.nn.Parameter(torch.ones(shape))
+ self.bias = torch.nn.Parameter(torch.zeros(shape))
+
+ def forward(self, input):
+ return torch.nn.functional.layer_norm(input, self.shape, self.weight, self.bias, 1e-5)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor([10, 10], dtype="float32"),
+ w2: R.Tensor([10, 10], dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
+ input_1,
+ w1,
+ w2,
+ axes=[-2, -1],
+ epsilon=1e-05,
+ center=True,
+ scale=True,
+ )
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ model = LayerNorm3([10, 10])
+ binding = {
+ "w1": model.weight.detach().numpy(),
+ "w2": model.bias.detach().numpy(),
+ }
+ verify_model(model, input_info, binding, expected3)
+
def test_cross_entropy():
input_info = [([3, 2], "float32"), ([3], "int32")]