This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3fd3a63652 [Relay][Pytorch] Add support for `aten::linalg_vector_norm` (#16123)
3fd3a63652 is described below

commit 3fd3a63652545b34db4a0b25354a3ec30253511b
Author: Masahiro Hiramori <mhg00...@gmail.com>
AuthorDate: Sat Nov 25 23:27:22 2023 -0800

    [Relay][Pytorch] Add support for `aten::linalg_vector_norm` (#16123)
    
    * add support for `aten::linalg_vector_norm`
    
    * add dtype check assertion
    
    * add double-precision testcase
    
    * Re-enable test_forward_norm and test_forward_frobenius_norm
    
    * cleanup test
    
    * rename `ord`->`order` to avoid W0622(redefined-builtin)
---
 python/tvm/relay/frontend/pytorch.py          | 27 +++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 27 +++++++++++++++++++++++++--
 2 files changed, 52 insertions(+), 2 deletions(-)
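
For context, the converter added below implements the standard vector p-norm and
its special cases; these are exactly the four branches in `linalg_vector_norm`:

    \|x\|_0 = \#\{\, i : x_i \neq 0 \,\}              (count of nonzero entries)
    \|x\|_{\infty} = \max_i |x_i|
    \|x\|_{-\infty} = \min_i |x_i|
    \|x\|_p = \Big(\sum_i |x_i|^p\Big)^{1/p}          (finite nonzero p)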

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index faed052a03..9374a24912 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3865,6 +3865,32 @@ class PyTorchOpConverter:
         # Return
         return _op.scatter_nd(source, indices, values, mode)
 
+    def linalg_vector_norm(self, inputs, input_types):
+        data = inputs[0]
+        dtype = input_types[0]
+        ord = inputs[1]
+        dim = inputs[2]
+        keepdim = inputs[3]
+
+        assert dtype == "float32" or dtype == "float64"
+
+        if ord == 0:
+            return _op.reduce.sum(
+                _op.cast(_op.not_equal(data, _expr.const(0, dtype=dtype)), dtype=dtype),
+                axis=dim,
+                keepdims=keepdim,
+            )
+        elif ord == np.inf:
+            return _op.reduce.max(_op.abs(data), axis=dim, keepdims=keepdim)
+        elif ord == np.NINF:
+            return _op.reduce.min(_op.abs(data), axis=dim, keepdims=keepdim)
+        reci_ord = _expr.const(1.0 / ord, dtype=dtype)
+        ord = _expr.const(ord, dtype=dtype)
+        return _op.power(
+            _op.reduce.sum(_op.power(_op.abs(data), ord), axis=dim, keepdims=keepdim),
+            reci_ord,
+        )
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4140,6 +4166,7 @@ class PyTorchOpConverter:
             "aten::_weight_norm": self.weight_norm,
             "aten::copy_": self.inplace_copy,
             "aten::swapaxes": self.transpose,
+            "aten::linalg_vector_norm": self.linalg_vector_norm,
         }
 
     def update_convert_map(self, custom_map):
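
A minimal end-to-end sketch (not part of the commit; the module name and input
shape are illustrative) of how the new mapping is exercised when importing a
traced model through the Relay PyTorch frontend:

    import torch
    import tvm
    from tvm import relay

    class VectorNorm(torch.nn.Module):
        def forward(self, x):
            # ord=3.5 takes the general finite-order branch of the converter
            return torch.linalg.vector_norm(x, ord=3.5)

    input_data = torch.rand(3, 3)
    traced = torch.jit.trace(VectorNorm().eval(), input_data)
    # "aten::linalg_vector_norm" in the traced graph now resolves via convert_map
    mod, params = relay.frontend.from_pytorch(traced, [("x", (3, 3))])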
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2f346feced..d9ecbce265 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1814,7 +1814,6 @@ def test_forward_logsoftmax():
     verify_model(LogSoftmax1().float().eval(), input_data=input_data)
 
 
-@pytest.mark.skip(reason="unsupported op aten::linalg_vector_norm")
 @tvm.testing.uses_gpu
 def test_forward_norm():
     """test_forward_norm"""
@@ -1874,7 +1873,6 @@ def test_forward_norm():
     verify_model(Norm10().float().eval(), input_data=input_data)
 
 
-@pytest.mark.skip(reason="unsupported op aten::linalg_vector_norm")
 @tvm.testing.uses_gpu
 def test_forward_frobenius_norm():
     """test_forward_frobenius_norm"""
@@ -5466,6 +5464,31 @@ def test_swapaxes():
     verify_model(Swapaxes3().float().eval(), input_data=input_data)
 
 
+def test_linalg_vector_norm():
+    """test_linalg_vector_norm"""
+    torch.set_grad_enabled(False)
+
+    def test_fn(order):
+        return lambda x: torch.linalg.vector_norm(x, ord=order)
+
+    input_shape = [3, 3]
+
+    input_data = torch.rand(input_shape).float()
+    verify_model(test_fn(order=2), input_data=input_data)
+    verify_model(test_fn(order=3.5), input_data=input_data)
+    verify_model(test_fn(order=np.inf), input_data=input_data)
+    verify_model(test_fn(order=np.NINF), input_data=input_data)
+    verify_model(test_fn(order=0), input_data=input_data)
+
+    # Also test on double
+    input_data = torch.rand(input_shape).double()
+    verify_model(test_fn(order=2), input_data=input_data)
+    verify_model(test_fn(order=3.5), input_data=input_data)
+    verify_model(test_fn(order=np.inf), input_data=input_data)
+    verify_model(test_fn(order=np.NINF), input_data=input_data)
+    verify_model(test_fn(order=0), input_data=input_data)
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with 
span tagged."""
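
As an independent sanity check (not from the test suite above), the general-order
branch computes (sum_i |x_i|^p)^(1/p), which can be verified against NumPy directly:

    import numpy as np
    import torch

    x = torch.rand(3, 3)
    expected = np.power(np.power(np.abs(x.numpy()), 3.5).sum(), 1.0 / 3.5)
    assert np.isclose(torch.linalg.vector_norm(x, ord=3.5).item(), expected)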
 
