This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c44088df05 [Relax][PyTorch] Fix `batch_norm.default` args handling in ExportedProgram frontend (#18486)
c44088df05 is described below

commit c44088df05ea9b3fec3d981c83b4341ec51939e9
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Nov 22 14:12:16 2025 +0900

    [Relax][PyTorch] Fix `batch_norm.default` args handling in ExportedProgram frontend (#18486)
    
    Properly handle positional args: after torch.export decomposition, batch_norm shows up as `_native_batch_norm_legit_*` variants whose momentum/eps argument positions differ from `batch_norm.default`, so detect those targets and index accordingly.
    cc @tlopex
---
 .../frontend/torch/exported_program_translator.py  | 27 ++++++---
 .../relax/test_frontend_from_exported_program.py   | 68 +++++++++++++++++++---
 2 files changed, 79 insertions(+), 16 deletions(-)
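
For reference, the argument positions the fix distinguishes, as encoded in the branches of the diff below: the decomposed `_native_batch_norm_legit_*` variants take momentum and eps at positions 5 and 6, while the fallback path keeps reading them at positions 6 and 7:

    # _native_batch_norm_legit_* (decomposed):
    #   args[5] -> momentum, args[6] -> eps
    # batch_norm.default (fallback path):
    #   args[6] -> momentum, args[7] -> eps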

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64af72c457..1961898f76 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -96,15 +96,28 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
         running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype))
         running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype))
-        ignore_running_stats = (
-            node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True)
-        )
-        track_running_stats = not ignore_running_stats
-        momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
-        eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)
 
-        if track_running_stats:
+        # After torch.export decomposition, batch_norm shows up as
+        # _native_batch_norm_legit_* with signature (x, weight, bias, mean, var, momentum, eps).
+        target_name = getattr(node.target, "__name__", "")
+        if target_name.startswith("_native_batch_norm_legit_no_training"):
+            momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
+            training = False
+        elif target_name.startswith("_native_batch_norm_legit_functional"):
+            momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
             training = True
+        else:
+            ignore_running_stats = (
+                node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True)
+            )
+            track_running_stats = not ignore_running_stats
+            momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)
+
+            if track_running_stats:
+                training = True
 
         return self.block_builder.emit(
             relax.op.nn.batch_norm(
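
The decomposition the new comment refers to can be checked directly; a minimal sketch of what the frontend sees (the module and tensor shapes are illustrative only, and the exact op the exporter emits depends on the PyTorch version):

    import torch

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01)

        def forward(self, x):
            return self.bn(x)

    # Export in eval mode and run the default decompositions, mirroring
    # what the Relax frontend does before translating the graph.
    ep = torch.export.export(M().eval(), (torch.randn(1, 3, 10, 10),))
    ep = ep.run_decompositions({})

    for node in ep.graph.nodes:
        if node.op == "call_function" and "batch_norm" in str(node.target):
            # For an eval-mode module, expect something like
            # aten._native_batch_norm_legit_no_training.default, where
            # momentum/eps sit at arg positions 5 and 6, not 6 and 7.
            print(node.target, node.args[5:])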
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a19c36ca22..01efb6b936 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1624,7 +1624,7 @@ def test_div_mode():
 
 
 def test_batchnorm2d():
-    class BatchNorm2d(Module):
+    class BatchNorm2d1(Module):
         def __init__(self):
             super().__init__()
             self.bn = torch.nn.BatchNorm2d(3)
@@ -1658,7 +1658,48 @@ def test_batchnorm2d():
                     epsilon=1e-05,
                     center=True,
                     scale=True,
-                    momentum=1e-05,
+                    momentum=0.1,
+                    training=False,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    class BatchNorm2dCustom(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.001,
+                    center=True,
+                    scale=True,
+                    momentum=0.01,
                     training=False,
                 )
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
@@ -1668,14 +1709,23 @@ def test_batchnorm2d():
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
 
-    model = BatchNorm2d().eval()
-    binding = {
-        "w1": model.bn.weight.detach().numpy(),
-        "w2": model.bn.bias.detach().numpy(),
-        "w3": model.bn.running_mean.detach().numpy(),
-        "w4": model.bn.running_var.detach().numpy(),
+    model_1 = BatchNorm2d1().eval()
+    binding_1 = {
+        "w1": model_1.bn.weight.detach().numpy(),
+        "w2": model_1.bn.bias.detach().numpy(),
+        "w3": model_1.bn.running_mean.detach().numpy(),
+        "w4": model_1.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model_1, example_args, binding_1, expected1)
+
+    model_2 = BatchNorm2dCustom().eval()
+    binding_2 = {
+        "w1": model_2.bn.weight.detach().numpy(),
+        "w2": model_2.bn.bias.detach().numpy(),
+        "w3": model_2.bn.running_mean.detach().numpy(),
+        "w4": model_2.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_2, example_args, binding_2, expected2)
 
 
 def test_adaptive_avgpool1d():
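
The new non-default eps/momentum case can be exercised in isolation from a TVM checkout (assuming pytest is available):

    pytest tests/python/relax/test_frontend_from_exported_program.py -k test_batchnorm2d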
