This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d5d3d81fe0 [Relax][PyTorch] Fix batch normalization training mode correctness (#18518)
d5d3d81fe0 is described below

commit d5d3d81fe09990b959a3b9db46ab12707095617c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 28 02:34:10 2025 +0800

    [Relax][PyTorch] Fix batch normalization training mode correctness (#18518)
    
    ## Why
    
    Batch normalization in training mode would throw away the updated
    statistics.
    
    ## How
    
    batch_norm(...) → keep all 3 elements of the returned tuple and pad to 5 results for PyTorch
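    
    For illustration only (not part of the patch), the new mapping is roughly:
    
    ```python
    # Sketch with illustrative names: in training mode relax.op.nn.batch_norm
    # returns a 3-field tuple (output, new running mean, new running var); the
    # importer keeps it and pads it to the 5 results the exported PyTorch
    # training-mode batch-norm node expects, filling the rest with zeros.
    output = bn_tuple[0]            # normalized tensor
    new_running_mean = bn_tuple[1]  # updated running mean
    reserve = relax.op.zeros(relax.ShapeExpr([channel]), dtype)  # placeholder
    padded = relax.Tuple([output, new_running_mean, reserve, reserve, reserve])
    ```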
---
 .../frontend/torch/exported_program_translator.py  | 34 ++++++----
 .../relax/test_frontend_from_exported_program.py   | 76 ++++++++++++++++++++++
 2 files changed, 99 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7af8774ee3..1f60d02a79 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -116,7 +116,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
 
     ########## Neural Network ##########
 
-    def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
+    def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var:
         import numpy as np
 
         x = self.env[node.args[0]]
@@ -149,7 +149,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             if track_running_stats:
                 training = True
 
-        return self.block_builder.emit(
+        bn_result = self.block_builder.emit(
             relax.op.nn.batch_norm(
                 data=x,
                 gamma=weight,
@@ -160,21 +160,33 @@ class ExportedProgramImporter(BaseFXGraphImporter):
                 epsilon=eps,
                 momentum=momentum,
                 training=training,
-            )[0]
+            )
         )
 
+        if return_tuple:
+            return bn_result
+        else:
+            # Return only the output tensor (for backward compatibility)
+            return self.block_builder.emit(bn_result[0])
+
     def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
         # This method is called for batch_norm in training mode
-        # TODO does not have correctness!
-        # TODO we need to store the running mean and variance returned by the
-        # previous call to batch_norm and pass it again
-        training = True
-        return self._batch_norm(node, training)
+        bn_tuple = self._batch_norm(node, training=True, return_tuple=True)
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+
+        output = self.block_builder.emit(bn_tuple[0])
+        new_running_mean = self.block_builder.emit(bn_tuple[1])
+        reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype))
+
+        return self.block_builder.emit(
+            relax.Tuple([output, new_running_mean, reserve, reserve, reserve])
+        )
 
     def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
-        # This method is called for batch_norm in eval mode
-        training = False
-        return self._batch_norm(node, training)
+        return self._batch_norm(node, training=False, return_tuple=False)
 
     def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
         import numpy as np
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 93218190fc..31743c2d12 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1803,6 +1803,82 @@ def test_batchnorm2d():
     }
     verify_model(model_2, example_args, binding_2, expected2)
 
+    class BatchNorm2dTraining(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((3,), dtype="float32"),
+            R.Tensor((), dtype="int64"),
+            R.Tensor((2, 3, 4, 4), dtype="float32"),
+        ):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+                lv1: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.1,
+                    center=True,
+                    scale=True,
+                    momentum=1.0,
+                    training=True,
+                )
+                lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+                lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+                lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+                lv5: R.Tuple(
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = (lv2, lv3, lv4, lv4, lv4)
+                lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+                lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+                lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+                gv: R.Tuple(
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((), dtype="int64"),
+                    R.Tensor((2, 3, 4, 4), dtype="float32"),
+                ) = (lv7, lv8, lv, lv6)
+                R.output(gv)
+            return gv
+
+    example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+    model_3 = BatchNorm2dTraining()
+    model_3.train()  # Set to training mode
+    binding_3 = {
+        "w1": model_3.bn.weight.detach().numpy(),
+        "w2": model_3.bn.bias.detach().numpy(),
+        "w3": model_3.bn.running_mean.detach().numpy(),
+        "w4": model_3.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_3, example_args_train, binding_3, expected3)
+
 
 def test_adaptive_avgpool1d():
     class AdaptiveAvgPool1d0(torch.nn.Module):

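For reference, a hypothetical end-to-end sketch of exercising the training-mode
path this commit fixes; the module and variable names are illustrative and not
part of the patch:

```python
import torch
from tvm.relax.frontend.torch import from_exported_program

# A BatchNorm2d kept in training mode, so the exported program contains the
# training-mode (functional) batch-norm node handled by the importer.
bn = torch.nn.BatchNorm2d(3, track_running_stats=True).train()
example_args = (torch.randn(2, 3, 4, 4),)

exported_program = torch.export.export(bn, example_args)
mod = from_exported_program(exported_program)
mod.show()  # the batch_norm result now retains the updated running statistics
```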