This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c5759738db [Relax][Onnx][BatchNorm] Pass momentum and training_mode
into BatchNorm Operator (#18704)
c5759738db is described below
commit c5759738dba668509cce383d5c112f593c0806ed
Author: Nguyen Duy Loc <[email protected]>
AuthorDate: Mon Feb 2 23:49:13 2026 +0700
[Relax][Onnx][BatchNorm] Pass momentum and training_mode into BatchNorm
Operator (#18704)
### Description
- The ONNX model has the training_mode attr set to False, but the Relax
model after conversion has training = True
- The momentum value in the Relax module does not match the one in the ONNX model
### Steps to Reproduce
<img width="600" height="400" alt="BatchNorm"
src="https://github.com/user-attachments/assets/2f0ca26b-e83b-4ab8-ab06-a537802af6de"
/>
- Relax model:
```
class Module:
    def main(X: R.Tensor((2, 3, 4, 4), dtype="float32")) -> R.Tensor((2, 3, 4, 4), dtype="float32"):
        R.func_attr({"num_input": 1})
        with R.dataflow():
            lv: R.Tuple(R.Tensor((2, 3, 4, 4), dtype="float32"), R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")) = R.nn.batch_norm(X, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1], metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3], axis=1, epsilon=9.9999997473787516e-06, center=True, scale=True, momentum=0.10000000000000001, training=True)
            lv1: R.Tensor((2, 3, 4, 4), dtype="float32") = lv[0]
            lv2: R.Tensor((3,), dtype="float32") = lv[1]
            lv3: R.Tensor((3,), dtype="float32") = lv[2]
            gv: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1
            R.output(gv)
        return gv
```
### Resolved
- Read the momentum and training_mode attributes, falling back to default
values when they are absent, and pass them into the BatchNorm operator
- Fixed: #18703
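
The attribute lookup described above can be sketched in isolation. This is a minimal illustration only: `resolve_batchnorm_attrs` is a hypothetical helper name, and a plain dict stands in for the converter's `attr` mapping. The default values are the ones used in this commit.

```python
def resolve_batchnorm_attrs(attr):
    """Hypothetical helper: look up BatchNormalization attributes with defaults.

    `attr` is assumed to behave like a dict of the ONNX node's attributes.
    """
    epsilon = attr.get("epsilon", 1e-05)
    momentum = attr.get("momentum", 0.9)  # default used when the node omits it
    training_mode = attr.get("training_mode", 0)  # 0 means inference mode
    return epsilon, momentum, training_mode


# A node with no explicit attributes falls back to the defaults.
defaults = resolve_batchnorm_attrs({})

# Explicit attributes on the node take precedence over the defaults.
explicit = resolve_batchnorm_attrs({"momentum": 0.5, "training_mode": 1})
```

With this lookup, a model exported without a training_mode attribute resolves to 0 rather than inheriting whatever default the target operator uses.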
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 784be639dd..61ab45d308 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2435,8 +2435,18 @@ class BatchNormalization(OnnxOpConverter):
mean = inputs[3]
var = inputs[4]
epsilon = attr.get("epsilon", 1e-05)
+ momentum = attr.get("momentum", 0.9)
+ training_mode = attr.get("training_mode", 0)
return relax.op.nn.batch_norm(
- data, gamma=scale, beta=bias, moving_mean=mean, moving_var=var, epsilon=epsilon, axis=1
+ data,
+ gamma=scale,
+ beta=bias,
+ moving_mean=mean,
+ moving_var=var,
+ axis=1,
+ epsilon=epsilon,
+ momentum=momentum,
+ training=training_mode,
)