This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f08165680e [Relax][PyTorch] Add support for norm op (#17841)
f08165680e is described below

commit f08165680e8b8696679d331e1f547de8839145cd
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Apr 17 11:23:33 2025 +0800

    [Relax][PyTorch] Add support for norm op (#17841)
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
---
 .../frontend/torch/base_fx_graph_translator.py     |  34 ++++++
 python/tvm/relax/frontend/torch/fx_translator.py   |   1 +
 tests/python/relax/test_frontend_from_fx.py        | 129 +++++++++++++++++++++
 3 files changed, 164 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6d880ab90d..4c9480b587 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -927,6 +927,40 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
 
+    def _norm(self, node: fx.Node) -> relax.Var:
+        data = self.env[node.args[0]]
+        dtype = data.struct_info.dtype
+        order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2)
+        axis = node.args[2] if len(node.args) > 2 else None
+        keepdims = node.args[3] if len(node.args) > 3 else False
+
+        if order == float("inf"):
+            return self.block_builder.emit(
+                relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)
+            )
+        elif order == float("-inf"):
+            return self.block_builder.emit(
+                relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)
+            )
+        # frobenius_norm
+        elif order == "fro":
+            return self.block_builder.emit(
+                relax.op.sqrt(
+                relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
+                )
+            )
+        else:
+            reci_order = relax.const(1 / order, dtype=dtype)
+            order = relax.const(order, dtype=dtype)
+            return self.block_builder.emit(
+                relax.op.power(
+                    relax.op.sum(
+                        relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
+                    ),
+                    reci_order,
+                )
+            )
+
     def _prod(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
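
A quick sanity check of the decomposition emitted by _norm above: p = +inf/-inf
reduce to max/min of |x|, "fro" is sqrt(sum(x*x)), and any other p becomes
(sum |x|**p) ** (1/p). The following is an illustrative sketch (not part of the
commit) that verifies those formulas against torch.norm using NumPy:

    import numpy as np
    import torch

    x = torch.randn(1, 3, 5, 3)
    a = x.numpy()

    # Reference values mirroring the branches of _norm.
    cases = {
        float("inf"): np.abs(a).max(),
        float("-inf"): np.abs(a).min(),
        "fro": np.sqrt((a * a).sum()),
        2.0: (np.abs(a) ** 2.0).sum() ** (1.0 / 2.0),
        0.5: (np.abs(a) ** 0.5).sum() ** (1.0 / 0.5),
    }
    for p, expected in cases.items():
        actual = torch.norm(x, p=p).item()
        assert np.isclose(actual, expected, rtol=1e-4), (p, actual, expected)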
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 594344fef8..297529e8bf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -728,6 +728,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "lerp": self._lerp,
             # statistical
             "mean": self._mean,
+            "norm": self._norm,
             "prod": self._prod,
             "std": self._std,
             "sum": self._sum,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index ee5a5c78c7..a962de8a32 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4513,5 +4513,134 @@ def test_narrow():
     verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_norm():
+
+    input_info = [([1, 3, 5, 3], "float32")]
+
+    class Norm(Module):
+        def __init__(self, p, dim=None, keepdim=False):
+            super().__init__()
+            self.p = p
+            self.dim = dim
+            self.keepdim = keepdim
+
+        def forward(self, x):
+            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), 
axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), 
axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, 
R.const(2, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, 
"float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected4:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, 
R.const(1.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, 
"float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected5:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, 
R.const(-4, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, 
R.const(-0.25, "float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected6:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, 
R.const(0.5, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(2, 
"float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected7:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = 
R.multiply(inp_0, inp_0)
+                lv1: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, 
keepdims=False)
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tensor((), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    norms = [
+        (float("inf"), None, False),
+        (float("-inf"), None, False),
+        (float(2), None, False),
+        (float(1.0), None, False),
+        (float(-4), None, True),
+        (float(0.5), None, True),
+        ("fro", None, False),
+    ]
+
+    for norm, expected in zip(
+        norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
+    ):
+        p, dim, keepdim = norm
+        verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
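
For context: torch.norm is deprecated upstream in favor of
torch.linalg.vector_norm / matrix_norm, but legacy call sites still emit "norm"
nodes, which is what the tests above exercise. With dim=None the two APIs agree
on the cases tested here; an illustrative check (not part of the commit):

    import torch

    x = torch.randn(1, 3, 5, 3)
    # With dim=None, legacy torch.norm flattens the input; p=2 and "fro"
    # both coincide with the flattened 2-norm.
    assert torch.allclose(torch.norm(x, p=2), torch.linalg.vector_norm(x, ord=2))
    assert torch.allclose(torch.norm(x, p="fro"), torch.linalg.vector_norm(x, ord=2))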
