This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7a38477b2f [Pytorch][Relay] aten::_weight_norm implementation (#13661)
7a38477b2f is described below
commit 7a38477b2f0c72c4c96645440fb7f8d07e4a25b3
Author: Matveenko Valery <[email protected]>
AuthorDate: Tue Dec 27 08:19:26 2022 +0100
[Pytorch][Relay] aten::_weight_norm implementation (#13661)
Add implementation for pytorch weight normalization
---
python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
2 files changed, 39 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b9d167ad2d..491c140c5c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3514,6 +3514,20 @@ class PyTorchOpConverter:
_, indices = _expr.TupleWrapper(output, 2)
return indices
def weight_norm(self, inputs, input_types):
    """Convert ``aten::_weight_norm(v, g, dim)`` to Relay.

    Computes ``g * (v / ||v||)`` where ``||v||`` is the L2 norm of ``v``
    reduced over every axis except ``dim``. Dims are kept so the norm
    broadcasts back against ``v``. Following PyTorch's ``norm_except_dim``,
    ``dim == -1`` means the norm is taken over the whole tensor. PyTorch's
    ``weight_norm`` passes ``-1`` when called with ``dim=None``.

    Parameters
    ----------
    inputs : list
        ``[weight_v, weight_g, dim]`` as supplied by the frontend; ``dim`` is
        expected to be a Python int.
    input_types : list of str
        Dtypes of ``inputs``; ``input_types[0]`` is the dtype of ``weight_v``.

    Returns
    -------
    Relay expression
        The re-parameterized weight ``weight_g * weight_v / norm(weight_v)``.
    """
    weight_v, weight_g = inputs[0], inputs[1]
    dim = inputs[2]
    dtype = input_types[0]
    order = 2.0
    # Both exponents must carry the weight's dtype: an untyped const defaults
    # to float32 and breaks type inference for float16/float64 weights.
    reci_order = _expr.const(1.0 / order, dtype=dtype)
    order_const = _expr.const(order, dtype=dtype)

    if dim == -1:
        # PyTorch's norm_except_dim treats -1 as "norm over all elements",
        # not "every axis except the last one".
        reduce_axis, exclude = None, False
    else:
        # Reduce over every axis *except* ``dim``.
        reduce_axis, exclude = dim, True

    norm_v = _op.power(
        _op.reduce.sum(
            _op.power(_op.abs(weight_v), order_const),
            axis=reduce_axis,
            exclude=exclude,
            keepdims=True,
        ),
        reci_order,
    )
    return weight_g * (weight_v / norm_v)
+
# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -3781,6 +3795,7 @@ class PyTorchOpConverter:
"aten::__lshift__": self.make_elemwise("left_shift"),
"aten::__rshift__": self.make_elemwise("right_shift"),
"aten::multinomial": self.multinomial,
+ "aten::_weight_norm": self.weight_norm,
}
def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 35242fbf7d..0035d202de 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5038,6 +5038,30 @@ def test_multinomial():
)
def test_weight_norm():
    """Test for aten::_weight_norm"""
    in_channels = 32
    out_channels = 64
    input_data_conv = torch.rand((1, in_channels, 32, 32)).float()

    # Conv2d with weight_norm at its default ``dim``.
    conv_wn = torch.nn.utils.weight_norm(
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3)
    )
    verify_model(conv_wn.eval().float(), input_data_conv)

    # Grouped convolution (groups=2).
    conv_wn_groups = torch.nn.utils.weight_norm(
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2)
    )
    verify_model(conv_wn_groups.eval().float(), input_data_conv)

    # Explicit non-default normalization dim.
    conv_wn_dim1 = torch.nn.utils.weight_norm(
        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3), dim=1
    )
    verify_model(conv_wn_dim1.eval().float(), input_data_conv)

    # Linear layer with weight_norm.
    linear_wn = torch.nn.utils.weight_norm(torch.nn.Linear(in_channels, out_channels))
    input_data_linear = torch.rand((128, in_channels)).float()
    verify_model(linear_wn.eval().float(), input_data_linear)
+
+
@tvm.testing.uses_gpu
def test_baddbmm():
def test_fn(alpha, beta):