This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d5d3d81fe0 [Relax][PyTorch] Fix batch normalization training mode correctness (#18518)
d5d3d81fe0 is described below
commit d5d3d81fe09990b959a3b9db46ab12707095617c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 28 02:34:10 2025 +0800
[Relax][PyTorch] Fix batch normalization training mode correctness (#18518)
## Why
Batch normalization in training mode would throw away the updated statistics.
## How
Keep all 3 elements returned by batch_norm(...) instead of only the first, and pad the result to the 5-element tuple PyTorch expects. A rough sketch of the padding follows.
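As a standalone illustration (not part of the patch; the helper name and the NumPy stand-ins are hypothetical), the padding works roughly like this: the normalized output and the updated running mean from the 3-element batch_norm result are kept, and zero-filled placeholder buffers occupy the remaining slots of the 5-element tuple that PyTorch's training-mode batch_norm node produces.

```python
import numpy as np

def pad_batch_norm_outputs(bn_result, channels, dtype="float32"):
    """Hypothetical helper mirroring the patch: keep the normalized output
    and the updated running mean from a 3-element batch_norm result, and
    fill the remaining slots of the 5-element PyTorch-style result with
    zero-filled placeholder buffers."""
    output, new_running_mean, _new_running_var = bn_result
    reserve = np.zeros((channels,), dtype=dtype)  # placeholder buffer
    return (output, new_running_mean, reserve, reserve, reserve)

# Usage with fake data for a 3-channel NCHW input
out = np.random.randn(2, 3, 4, 4).astype("float32")
mean = np.zeros(3, dtype="float32")
var = np.ones(3, dtype="float32")
padded = pad_batch_norm_outputs((out, mean, var), channels=3)
assert len(padded) == 5
```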
---
.../frontend/torch/exported_program_translator.py | 34 ++++++----
.../relax/test_frontend_from_exported_program.py | 76 ++++++++++++++++++++++
2 files changed, 99 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 7af8774ee3..1f60d02a79 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -116,7 +116,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
########## Neural Network ##########
- def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
+ def _batch_norm(self, node: fx.Node, training: bool, return_tuple: bool = False) -> relax.Var:
import numpy as np
x = self.env[node.args[0]]
@@ -149,7 +149,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if track_running_stats:
training = True
- return self.block_builder.emit(
+ bn_result = self.block_builder.emit(
relax.op.nn.batch_norm(
data=x,
gamma=weight,
@@ -160,21 +160,33 @@ class ExportedProgramImporter(BaseFXGraphImporter):
epsilon=eps,
momentum=momentum,
training=training,
- )[0]
+ )
)
+ if return_tuple:
+ return bn_result
+ else:
+ # Return only the output tensor (for backward compatibility)
+ return self.block_builder.emit(bn_result[0])
+
def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
# This method is called for batch_norm in training mode
- # TODO does not have correctness!
- # TODO we need to store the running mean and variance returned by the
- # previous call to batch_norm and pass it again
- training = True
- return self._batch_norm(node, training)
+ bn_tuple = self._batch_norm(node, training=True, return_tuple=True)
+
+ x = self.env[node.args[0]]
+ channel = int(self.shape_of(x)[1])
+ dtype = x.struct_info.dtype
+
+ output = self.block_builder.emit(bn_tuple[0])
+ new_running_mean = self.block_builder.emit(bn_tuple[1])
+ reserve = self.block_builder.emit(relax.op.zeros(relax.ShapeExpr([channel]), dtype))
+
+ return self.block_builder.emit(
+ relax.Tuple([output, new_running_mean, reserve, reserve, reserve])
+ )
def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
- # This method is called for batch_norm in eval mode
- training = False
- return self._batch_norm(node, training)
+ return self._batch_norm(node, training=False, return_tuple=False)
def _batch_norm_legit_no_stats(self, node: fx.Node) -> relax.Var:
import numpy as np
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 93218190fc..31743c2d12 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1803,6 +1803,82 @@ def test_batchnorm2d():
}
verify_model(model_2, example_args, binding_2, expected2)
+ class BatchNorm2dTraining(Module):
+ def __init__(self):
+ super().__init__()
+ self.bn = torch.nn.BatchNorm2d(3, track_running_stats=True)
+
+ def forward(self, input):
+ return self.bn(input)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((2, 3, 4, 4), dtype="float32"),
+ w1: R.Tensor((3,), dtype="float32"),
+ w2: R.Tensor((3,), dtype="float32"),
+ w3: R.Tensor((3,), dtype="float32"),
+ w4: R.Tensor((3,), dtype="float32"),
+ ) -> R.Tuple(
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((), dtype="int64"),
+ R.Tensor((2, 3, 4, 4), dtype="float32"),
+ ):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="int64") = R.add(R.const(0, "int64"), R.const(1, "int64"))
+ lv1: R.Tuple(
+ R.Tensor((2, 3, 4, 4), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ ) = R.nn.batch_norm(
+ input_1,
+ w1,
+ w2,
+ w3,
+ w4,
+ axis=1,
+ epsilon=0.1,
+ center=True,
+ scale=True,
+ momentum=1.0,
+ training=True,
+ )
+ lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
+ lv3: R.Tensor((3,), dtype="float32") = lv1[1]
+ lv4: R.Tensor((3,), dtype="float32") = R.zeros(R.shape([3]), dtype="float32")
+ lv5: R.Tuple(
+ R.Tensor((2, 3, 4, 4), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ ) = (lv2, lv3, lv4, lv4, lv4)
+ lv6: R.Tensor((2, 3, 4, 4), dtype="float32") = lv5[0]
+ lv7: R.Tensor((3,), dtype="float32") = lv5[3]
+ lv8: R.Tensor((3,), dtype="float32") = lv5[4]
+ gv: R.Tuple(
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((), dtype="int64"),
+ R.Tensor((2, 3, 4, 4), dtype="float32"),
+ ) = (lv7, lv8, lv, lv6)
+ R.output(gv)
+ return gv
+
+ example_args_train = (torch.randn(2, 3, 4, 4, dtype=torch.float32),)
+
+ model_3 = BatchNorm2dTraining()
+ model_3.train() # Set to training mode
+ binding_3 = {
+ "w1": model_3.bn.weight.detach().numpy(),
+ "w2": model_3.bn.bias.detach().numpy(),
+ "w3": model_3.bn.running_mean.detach().numpy(),
+ "w4": model_3.bn.running_var.detach().numpy(),
+ }
+ verify_model(model_3, example_args_train, binding_3, expected3)
+
def test_adaptive_avgpool1d():
class AdaptiveAvgPool1d0(torch.nn.Module):