This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7a38477b2f [Pytorch][Relay] aten::_weight_norm implementation (#13661)
7a38477b2f is described below

commit 7a38477b2f0c72c4c96645440fb7f8d07e4a25b3
Author: Matveenko Valery <[email protected]>
AuthorDate: Tue Dec 27 08:19:26 2022 +0100

    [Pytorch][Relay] aten::_weight_norm implementation (#13661)
    
    Add implementation for pytorch weight normalization
---
 python/tvm/relay/frontend/pytorch.py          | 15 +++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b9d167ad2d..491c140c5c 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3514,6 +3514,20 @@ class PyTorchOpConverter:
         _, indices = _expr.TupleWrapper(output, 2)
         return indices
 
+    def weight_norm(self, inputs, input_types):
+        weight_v, weight_g = inputs[0], inputs[1]
+        dim = inputs[2]
+        dtype = input_types[0]
+        order = 2.0
+        reci_order = _expr.const(1.0 / order, dtype=dtype)
+        order = _expr.const(order)
+
+        norm_v = _op.power(
+            _op.reduce.sum(_op.power(_op.abs(weight_v), order), axis=dim, exclude=2, keepdims=True),
+            reci_order,
+        )
+        return weight_g * (weight_v / norm_v)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3781,6 +3795,7 @@ class PyTorchOpConverter:
             "aten::__lshift__": self.make_elemwise("left_shift"),
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
+            "aten::_weight_norm": self.weight_norm,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 35242fbf7d..0035d202de 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5038,6 +5038,30 @@ def test_multinomial():
     )
 
 
+def test_weight_norm():
+    """Test for aten::_weight_norm"""
+    in_channels = 32
+    out_channels = 64
+    input_data_conv = torch.rand((1, in_channels, 32, 32)).float()
+
+    conv_wn = torch.nn.utils.weight_norm(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3))
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    conv_wn_groups = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2)
+    )
+    verify_model(conv_wn_groups.eval().float(), input_data_conv)
+
+    conv_wn = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3), dim=1
+    )
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    linear_wn = torch.nn.utils.weight_norm(torch.nn.Linear(in_channels, out_channels))
+    input_data_linear = torch.rand((128, in_channels)).float()
+    verify_model(linear_wn.eval().float(), input_data_linear)
+
+
 @tvm.testing.uses_gpu
 def test_baddbmm():
     def test_fn(alpha, beta):

Reply via email to