This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a9955e55a7 [Relax][PyTorch] Add decomposed operator support for normalization (#18460)
a9955e55a7 is described below

commit a9955e55a7345b764db621e9c35f65451824cbd5
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Nov 16 14:06:17 2025 +0800

    [Relax][PyTorch] Add decomposed operator support for normalization (#18460)
    
    ## Related PR
    
    - https://github.com/apache/tvm/pull/18401
    
    ## How
    
    This PR (see the usage sketch below):
    
    - added `_batch_norm_legit_no_stats`
    - added `_native_group_norm`
    - added `any.dims`
    - refactored `_reshape` to skip identity reshapes
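    
    A minimal usage sketch (not part of this PR; the module and shapes are
    illustrative) of how the new handlers get exercised:
    
    ```python
    import torch
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program
    
    class Norms(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # GroupNorm is expected to decompose to native_group_norm.default
            self.gn = torch.nn.GroupNorm(num_groups=3, num_channels=3)
            # InstanceNorm2d without running stats is expected to decompose
            # to _native_batch_norm_legit.no_stats
            self.inorm = torch.nn.InstanceNorm2d(3)
    
        def forward(self, x):
            return self.inorm(self.gn(x))
    
    ep = export(Norms(), args=(torch.randn(1, 3, 8, 8),))
    mod = from_exported_program(ep, run_ep_decomposition=True)
    ```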
---
 .../frontend/torch/base_fx_graph_translator.py     |  6 +++
 .../frontend/torch/exported_program_translator.py  | 50 ++++++++++++++++++++++
 .../relax/test_frontend_from_exported_program.py   | 24 ++++++-----
 3 files changed, 69 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 7b8c51895c..b03723cb91 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1848,6 +1848,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         args = self.retrieve_args(node)
         x = args[0]
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+
+        # Skip identity reshape
+        current_shape = self.shape_of(x)
+        if list(current_shape) == list(dims):
+            return x
+
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
     def _reshape_as(self, node: fx.Node) -> relax.Var:
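
A quick illustration of the identity-reshape skip above (a hedged sketch, not part of the patch; the module and shapes are illustrative, and `run_ep_decomposition` comes from this PR's tests): a reshape whose target matches the input's static shape now returns the input unchanged instead of emitting `relax.op.reshape`.

```python
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class SameShape(torch.nn.Module):
    def forward(self, x):
        # Target shape equals the input shape, so the importer skips the reshape
        return x.reshape(10, 10)

ep = export(SameShape(), args=(torch.randn(10, 10),))
mod = from_exported_program(ep, run_ep_decomposition=True)
```
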
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2a119e111b..63aba55a78 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -113,6 +113,31 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         training = False
         return self._batch_norm(node, training)
 
+    def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
+        import numpy as np
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+        weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype))
+        bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
+        eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps", 1e-05)
+
+        # Determine axes for instance norm (all spatial dimensions after channel)
+        dim = len(self.shape_of(x))
+        axes = list(range(2, dim))
+
+        return self.block_builder.emit(
+            relax.op.nn.instance_norm(
+                x,
+                weight,
+                bias,
+                channel_axis=1,
+                axes=axes,
+                epsilon=eps,
+            )
+        )
+
     def _cross_entropy_default(self, node: fx.Node) -> relax.Expr:
         preds = self.env[node.args[0]]
         targets = self.env[node.args[1]]
@@ -141,6 +166,28 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             )
         )
 
+    def _native_group_norm(self, node: fx.Node) -> relax.Var:
+        # native_group_norm signature: (input, weight, bias, N, C, HxW, group, eps)
+        x = self.env[node.args[0]]
+        gamma = self.env.get(node.args[1], None) if len(node.args) > 1 else None
+        beta = self.env.get(node.args[2], None) if len(node.args) > 2 else None
+        # args[3] = N (batch size), args[4] = C (channels), args[5] = HxW (spatial size)
+        num_groups = node.args[6] if len(node.args) > 6 else 1
+        eps = node.args[7] if len(node.args) > 7 else 1e-05
+
+        dim = len(self.shape_of(x))
+        return self.block_builder.emit(
+            relax.op.nn.group_norm(
+                x,
+                gamma,
+                beta,
+                num_groups=num_groups,
+                channel_axis=1,
+                axes=list(range(2, dim)),
+                epsilon=eps,
+            )
+        )
+
     def _upsample_impl(
         self,
         x: relax.Expr,
@@ -963,6 +1010,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "_adaptive_avg_pool3d.default": self._adaptive_avg_pool3d,
             "_native_batch_norm_legit_functional.default": 
self._batch_norm_legit_functional,
             "_native_batch_norm_legit_no_training.default": 
self._batch_norm_legit_no_training,
+            "_native_batch_norm_legit.no_stats": 
self._batch_norm_legit_no_stats,
             "batch_norm.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool1d.default": self._adaptive_avg_pool1d,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
@@ -988,6 +1036,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             ),
             "group_norm.default": self._group_norm,
             "instance_norm.default": self._instance_norm,
+            "native_group_norm.default": self._native_group_norm,
             "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
             "lstm.input": self._lstm,
@@ -1004,6 +1053,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "upsample_bicubic2d.vec": self._upsample_bicubic2d,
             # statistical
             "any.dim": self._any,
+            "any.dims": self._any,
             "mean.dim": self._mean,
             "prod.default": self._prod,
             "std.correction": self._std,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index f571ee1fd9..1b816432ce 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1514,12 +1514,10 @@ def test_isin():
            x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
             with R.dataflow():
-                lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
-                lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
-                lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
-                lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
-                lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
-                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
+                lv: R.Tensor((10, 10, 1), dtype="float32") = R.reshape(x, R.shape([10, 10, 1]))
+                lv1: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, test_elements)
+                lv2: R.Tensor((10, 10), dtype="bool") = R.max(lv1, axis=[-1], keepdims=False)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv2,)
                 R.output(gv)
             return gv
 
@@ -1527,7 +1525,7 @@ def test_isin():
         torch.randn(10, 10, dtype=torch.float32),
         torch.randn(8, dtype=torch.float32),
     )
-    verify_model(IsInModel(), example_args, {}, expected)
+    verify_model(IsInModel(), example_args, {}, expected, run_ep_decomposition=True)
 
 
 def test_div_mode():
@@ -3155,7 +3153,7 @@ def test_groupnorm():
         "w1": model.gn.weight.detach().numpy(),
         "w2": model.gn.bias.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
 
 def test_instancenorm2d():
@@ -3200,7 +3198,7 @@ def test_instancenorm2d():
         "w1": torch.ones(3).detach().numpy(),
         "w2": torch.zeros(3).detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
 
 def test_layernorm():
@@ -5556,7 +5554,9 @@ def test_unwrap_unit_return_tuple():
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
     exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
+    mod = from_exported_program(
+        exported_program, unwrap_unit_return_tuple=True, run_ep_decomposition=True
+    )
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
@@ -5586,7 +5586,9 @@ def test_no_bind_return_tuple():
         torch.randn(256, 256, dtype=torch.float32),
     )
     exported_program = export(Identity(), args=example_args)
-    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
+    mod = from_exported_program(
+        exported_program, no_bind_return_tuple=True, run_ep_decomposition=True
+    )
     tvm.ir.assert_structural_equal(mod, Expected)
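
For reference (not part of the patch): a hedged sketch of a model that should hit the newly registered `any.dims` table entry on a sufficiently recent PyTorch; the wrapper module and shapes are illustrative.

```python
import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class AnyDims(torch.nn.Module):
    def forward(self, x):
        # any over a tuple of dims lowers to the aten any.dims overload
        return torch.any(x, dim=(0, 1))

ep = export(AnyDims(), args=(torch.randn(2, 3, 4) > 0,))
mod = from_exported_program(ep, run_ep_decomposition=True)
```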
 
 
