This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c44088df05 [Relax][PyTorch] Fix `batch_norm.default` args handling in ExportedProgram frontend (#18486)
c44088df05 is described below
commit c44088df05ea9b3fec3d981c83b4341ec51939e9
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sat Nov 22 14:12:16 2025 +0900
[Relax][PyTorch] Fix `batch_norm.default` args handling in ExportedProgram frontend (#18486)

Properly handle args: after torch.export decomposition, batch_norm shows up
as _native_batch_norm_legit_* with signature (x, weight, bias, mean, var,
momentum, eps), so momentum and eps must be read from args[5] and args[6]
rather than args[6] and args[7].
cc @tlopex
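
For context, a minimal repro sketch (illustrative, not part of the commit;
assumes a recent PyTorch with torch.export, and shapes are arbitrary) showing
the decomposed target whose args this patch parses:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

# Exporting in eval mode decomposes batch_norm into
# _native_batch_norm_legit_no_training(x, weight, bias, mean, var, momentum, eps).
ep = torch.export.export(M().eval(), (torch.randn(1, 3, 10, 10),))
for node in ep.graph.nodes:
    if node.op == "call_function" and "batch_norm" in str(node.target):
        print(node.target, len(node.args))  # momentum sits at args[5], eps at args[6]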
---
.../frontend/torch/exported_program_translator.py | 27 ++++++---
.../relax/test_frontend_from_exported_program.py | 68 +++++++++++++++++++---
2 files changed, 79 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64af72c457..1961898f76 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -96,15 +96,28 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
         running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype))
         running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype))
-        ignore_running_stats = (
-            node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True)
-        )
-        track_running_stats = not ignore_running_stats
-        momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
-        eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)
-        if track_running_stats:
+        # After torch.export decomposition, batch_norm shows up as
+        # _native_batch_norm_legit_* with signature (x, weight, bias, mean, var, momentum, eps).
+        target_name = getattr(node.target, "__name__", "")
+        if target_name.startswith("_native_batch_norm_legit_no_training"):
+            momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
+            training = False
+        elif target_name.startswith("_native_batch_norm_legit_functional"):
+            momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
             training = True
+        else:
+            ignore_running_stats = (
+                node.args[5] if len(node.args) > 5 else node.kwargs.get("track_running_stats", True)
+            )
+            track_running_stats = not ignore_running_stats
+            momentum = node.args[6] if len(node.args) > 6 else node.kwargs.get("momentum", 0.1)
+            eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps", 1e-05)
+
+            if track_running_stats:
+                training = True
         return self.block_builder.emit(
             relax.op.nn.batch_norm(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index a19c36ca22..01efb6b936 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1624,7 +1624,7 @@ def test_div_mode():
 def test_batchnorm2d():
-    class BatchNorm2d(Module):
+    class BatchNorm2d1(Module):
         def __init__(self):
             super().__init__()
             self.bn = torch.nn.BatchNorm2d(3)
@@ -1658,7 +1658,48 @@ def test_batchnorm2d():
                     epsilon=1e-05,
                     center=True,
                     scale=True,
-                    momentum=1e-05,
+                    momentum=0.1,
+                    training=False,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    class BatchNorm2dCustom(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=0.001,
+                    center=True,
+                    scale=True,
+                    momentum=0.01,
                     training=False,
                 )
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
@@ -1668,14 +1709,23 @@ def test_batchnorm2d():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    model = BatchNorm2d().eval()
-    binding = {
-        "w1": model.bn.weight.detach().numpy(),
-        "w2": model.bn.bias.detach().numpy(),
-        "w3": model.bn.running_mean.detach().numpy(),
-        "w4": model.bn.running_var.detach().numpy(),
+    model_1 = BatchNorm2d1().eval()
+    binding_1 = {
+        "w1": model_1.bn.weight.detach().numpy(),
+        "w2": model_1.bn.bias.detach().numpy(),
+        "w3": model_1.bn.running_mean.detach().numpy(),
+        "w4": model_1.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model_1, example_args, binding_1, expected1)
+
+    model_2 = BatchNorm2dCustom().eval()
+    binding_2 = {
+        "w1": model_2.bn.weight.detach().numpy(),
+        "w2": model_2.bn.bias.detach().numpy(),
+        "w3": model_2.bn.running_mean.detach().numpy(),
+        "w4": model_2.bn.running_var.detach().numpy(),
+    }
+    verify_model(model_2, example_args, binding_2, expected2)
 def test_adaptive_avgpool1d():
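
As a usage note, a minimal sketch of exercising the fixed path end to end
(not part of the commit; it assumes from_exported_program is the Relax
PyTorch frontend's public entry point, and the BatchNorm2d hyperparameters
are illustrative):

import torch
import tvm
from tvm.relax.frontend.torch import from_exported_program

model = torch.nn.Sequential(torch.nn.BatchNorm2d(3, eps=0.001, momentum=0.01)).eval()
ep = torch.export.export(model, (torch.randn(1, 3, 10, 10),))
mod = from_exported_program(ep)
# With this fix, the emitted R.nn.batch_norm carries epsilon=0.001 and
# momentum=0.01 instead of values read from the wrong positional slots.
mod.show()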