This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1ef153fd1c [Relax][PyTorch] Refactor norm op for ExportedProgram importer (#17857)
1ef153fd1c is described below

commit 1ef153fd1c2346d7fc46cc45483b6582927aefcd
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Apr 23 10:59:28 2025 +0800

    [Relax][PyTorch] Refactor norm op for ExportedProgram importer (#17857)
    
    * Update base_fx_graph_translator.py
    
    * Update fx_translator.py
    
    * Update exported_program_translator.py
    
    * Update exported_program_translator.py
    
    * Update fx_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update fx_translator.py
    
    * Update exported_program_translator.py
    
    * Update test_frontend_from_exported_program.py
    
    * Update test_frontend_from_fx.py
    
    * Update fx_translator.py
    
    * Update exported_program_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update exported_program_translator.py
    
    * Update exported_program_translator.py
    
    * Update fx_translator.py
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update exported_program_translator.py
---
 .../frontend/torch/base_fx_graph_translator.py     |  50 ++-------
 .../frontend/torch/exported_program_translator.py  |   2 +-
 .../relax/test_frontend_from_exported_program.py   | 112 +++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        |  19 ++--
 4 files changed, 131 insertions(+), 52 deletions(-)
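
For context, the refactor removes a duplicated lowering: both the deleted
_linalg_vector_norm and the shared _norm helper decompose a vector p-norm as
abs, power, sum, and a reciprocal power. A minimal PyTorch sketch of that
identity (the shape and ord value are illustrative, not taken from the patch):

    import torch

    x = torch.randn(1, 3, 5, 3)
    p = 0.5
    # ||x||_p = (sum(|x|^p)) ** (1/p) -- the op chain the importer emits
    manual = torch.sum(torch.abs(x) ** p) ** (1.0 / p)
    reference = torch.linalg.vector_norm(x, ord=p)
    assert torch.allclose(manual, reference)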

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 20556167c1..a89726495e 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -417,42 +417,6 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self.block_builder.emit(relax.op.subtract(rhs, lhs))
 
-    ########## Linear Algebra ##########
-
-    def _linalg_vector_norm(self, node: fx.Node) -> relax.Var:
-
-        args = self.retrieve_args(node)
-
-        data = args[0]
-        # Default ord=2 if not supplied
-        ord_val = args[1] if len(args) > 1 else 2.0
-        dim = args[2] if len(args) > 2 else None
-        keepdim = args[3] if len(args) > 3 else False
-
-        # If ord_val is a Python float/int, wrap it in a Relax const
-        # so that it matches data's dtype.
-        dtype = data.struct_info.dtype
-        ord_expr = (
-            ord_val if isinstance(ord_val, relax.Expr) else relax.const(float(ord_val), dtype)
-        )
-        # Reciprocal
-        reci_expr = (
-            relax.op.divide(relax.const(1.0, dtype), ord_expr)
-            if isinstance(ord_val, relax.Expr)
-            else relax.const(1.0 / float(ord_val), dtype)
-        )
-
-        # abs(data)
-        abs_data = self.block_builder.emit(relax.op.abs(data))
-        # abs_data^ord
-        abs_data_pow = self.block_builder.emit(relax.op.power(abs_data, ord_expr))
-        # sum over dim
-        reduced = self.block_builder.emit(relax.op.sum(abs_data_pow, dim, keepdims=keepdim))
-        # (sum(...))^(1/ord)
-        norm_val = self.block_builder.emit(relax.op.power(reduced, reci_expr))
-
-        return norm_val
-
     ########## Neural Network ##########
 
     def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
@@ -980,16 +944,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         elif order == "fro":
             return self.block_builder.emit(
                 relax.op.sqrt(
-                    relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
+                    relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims)
                 )
             )
         else:
-            reci_order = relax.const(1 / order, dtype=dtype)
-            order = relax.const(order, dtype=dtype)
+            ord_expr = (
+                order if isinstance(order, relax.Expr) else relax.const(float(order), dtype=dtype)
+            )
+            reci_order = (
+                relax.op.divide(relax.const(1.0, dtype), ord_expr)
+                if isinstance(order, relax.Expr)
+                else relax.const(1.0 / order, dtype=dtype)
+            )
             return self.block_builder.emit(
                 relax.op.power(
                     relax.op.sum(
-                        relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
+                        relax.op.power(relax.op.abs(data), ord_expr), axis=axis, keepdims=keepdims
                     ),
                     reci_order,
                 )
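
The substantive change above is that _norm now accepts the order either as a
Python number or as a relax.Expr (the form an exported graph can supply), and
builds the reciprocal exponent to match. A standalone sketch of that dispatch;
the helper name make_ord_exprs is hypothetical and not part of the patch:

    from tvm import relax

    def make_ord_exprs(order, dtype):
        # Mirror the branch in _norm: wrap plain numbers as constants of the
        # input dtype; keep relax expressions symbolic and divide for 1/ord.
        if isinstance(order, relax.Expr):
            ord_expr = order
            reci_order = relax.op.divide(relax.const(1.0, dtype), ord_expr)
        else:
            ord_expr = relax.const(float(order), dtype)
            reci_order = relax.const(1.0 / float(order), dtype)
        return ord_expr, reci_order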
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index f38f353a9e..cdf0c46bb5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -369,7 +369,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "__xor__.Tensor": self._binary_op(relax.op.bitwise_xor, 
operator.xor),
             "__xor__.Scalar": self._binary_op(relax.op.bitwise_xor, 
operator.xor),
             # linear algebra
-            "linalg_vector_norm.default": self._linalg_vector_norm,
+            "linalg_vector_norm.default": self._norm,
             # neural network
             "_native_batch_norm_legit_functional.default": 
self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": 
self._batch_norm_legit_no_training,
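
With linalg_vector_norm.default routed to the shared _norm handler, an
exported torch.linalg.vector_norm call now imports through the same path as
torch.norm. A minimal usage sketch (module and shapes are illustrative):

    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class VectorNorm(torch.nn.Module):
        def forward(self, x):
            return torch.linalg.vector_norm(x, ord=2.0)

    example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
    mod = from_exported_program(export(VectorNorm(), example_args))
    mod.show()  # the norm lowers to R.abs / R.power / R.sum / R.power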
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a386a989f0..c6ead5aacc 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4379,6 +4379,118 @@ def test_narrow():
     verify_model(Narrow(), example_args, {}, Expected)
 
 
+def test_norm():
+    class Norm(Module):
+        def __init__(self, p, dim=None, keepdim=False):
+            super().__init__()
+            self.p = p
+            self.dim = dim
+            self.keepdim = keepdim
+
+        def forward(self, x):
+            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), 
axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), 
axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected4:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected5:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32"))
+                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(
+                    lv2, R.const(-0.25, "float32")
+                )
+                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected6:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True)
+                lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32"))
+                gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    norms = [
+        ((float("inf"), None, False), Expected1),
+        ((float("-inf"), None, False), Expected2),
+        ((float(2), None, False), Expected3),
+        ((float(1.0), None, False), Expected4),
+        ((float(-4), None, True), Expected5),
+        ((float(0.5), None, True), Expected6),
+    ]
+
+    example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)
+
+    for (p, dim, keepdim), expected in norms:
+        verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)
+
+
 def test_eye():
     class Eye1(Module):
         def forward(self, input):
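
As a usage note, the new cases above can be run in isolation with the standard
pytest node-id syntax, e.g.:

    pytest tests/python/relax/test_frontend_from_exported_program.py::test_norm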
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 2d27fa1f59..f21cde6df2 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4947,19 +4947,16 @@ def test_norm():
             return gv
 
     norms = [
-        (float("inf"), None, False),
-        (float("-inf"), None, False),
-        (float(2), None, False),
-        (float(1.0), None, False),
-        (float(-4), None, True),
-        (float(0.5), None, True),
-        ("fro", None, False),
+        ((float("inf"), None, False), Expected1),
+        ((float("-inf"), None, False), Expected2),
+        ((float(2), None, False), Expected3),
+        ((float(1.0), None, False), Expected4),
+        ((float(-4), None, True), Expected5),
+        ((float(0.5), None, True), Expected6),
+        (("fro", None, False), Expected7),
     ]
 
-    for norm, expected in zip(
-        norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
-    ):
-        p, dim, keepdim = norm
+    for (p, dim, keepdim), expected in norms:
         verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
 
 
