This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bc7809535d [Unity][Frontend] Import `tanh` and fix `layer_norm`
(#14247)
bc7809535d is described below
commit bc7809535d03f690f6d40848e4eb9a53bd14f64c
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Mar 10 22:26:00 2023 +0800
[Unity][Frontend] Import `tanh` and fix `layer_norm` (#14247)
This PR provides some quick fixes for fx_translator to import tanh and fix
the error when importing torch.nn.functional.layer_norm.
---
python/tvm/relax/frontend/torch/fx_translator.py | 32 +++++++++
tests/python/relax/test_frontend_from_fx.py | 82 ++++++++++++++++++++++++
2 files changed, 114 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a157b80337..41e8e775a4 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -609,8 +609,38 @@ class TorchFXImporter:
def _layer_norm(self, node: fx.node.Node) -> relax.Var:
import torch # type: ignore
+ import numpy as np # type: ignore
x = self.env[node.args[0]]
+
+ # functional.layer_norm
+ if node.target not in self.named_modules:
+ # static or symbolic
+ normalized_shape = (
+ node.args[1] if type(node.args[1]) == tuple else
self.env[node.args[1]]
+ )
+ dim_num = len(normalized_shape)
+ axes = list(range(-dim_num, 0))
+
+ gamma = self.env[node.kwargs["weight"]]
+ beta = node.kwargs["bias"]
+ if beta is None:
+ shape_tuple = [int(s) for s in normalized_shape.values]
+ beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
+ else:
+ beta = self.env[beta]
+ eps = node.kwargs["eps"]
+
+ return self.block_builder.emit(
+ relax.op.nn.layer_norm(
+ x,
+ gamma,
+ beta,
+ axes=axes,
+ epsilon=eps,
+ )
+ )
+
module = self.named_modules[node.target]
if module.elementwise_affine:
@@ -886,6 +916,7 @@ class TorchFXImporter:
"clamp": self._clamp,
"relu": lambda node:
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
"gelu": lambda node:
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+ "tanh": lambda node:
self.block_builder.emit(relax.op.tanh(self.env[node.args[0]])),
"interpolate": self._interpolate,
"size": self._size,
"getattr": self._getattr,
@@ -893,6 +924,7 @@ class TorchFXImporter:
"contiguous": lambda node: self.env[node.args[0]],
"to": lambda node: self.env[node.args[0]],
"adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
+ "layer_norm": self._layer_norm,
}
def from_fx(
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 3447c99e5c..6467b6cf14 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -761,6 +761,58 @@ def test_layernorm():
verify_model(LayerNorm(), input_info, binding, expected1)
+@tvm.testing.requires_gpu
+def test_functional_layernorm():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ class LayerNorm(Module):
+ def __init__(self, shape):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.ones(shape))
+ self.bias = torch.nn.Parameter(torch.zeros(shape))
+
+ def forward(self, input):
+ return torch.nn.functional.layer_norm(
+ input, self.weight.shape, self.weight, self.bias, 1e-5
+ )
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((10, 10), dtype="float32"),
+ w2: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.layer_norm(
+ input_1,
+ w1,
+ w2,
+ axes=[-2, -1],
+ epsilon=1e-05,
+ center=True,
+ scale=True,
+ )
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ model = LayerNorm((10, 10))
+ binding = {
+ "w1": model.weight.numpy(),
+ "w2": model.bias.numpy(),
+ }
+ verify_model(model, input_info, binding, expected1)
+
+
@tvm.testing.requires_gpu
def test_silu():
import torch
@@ -1490,6 +1542,36 @@ def test_gelu():
verify_model(Gelu(), input_info, {}, expected1)
+@tvm.testing.requires_gpu
+def test_tanh():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ input_info = [([1, 3, 10, 10], "float32")]
+
+ class Tanh(Module):
+ def forward(self, input):
+ return torch.tanh(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.tanh(input_1)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Tanh(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_clamp():
import torch