This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b4475b83aa [Bugfix] Fix batch_norm (#14857)
b4475b83aa is described below
commit b4475b83aa3f51e7fdc852ba6bd596ff4c8cd4ff
Author: Quanfeng Li <[email protected]>
AuthorDate: Thu May 18 14:36:15 2023 +0800
[Bugfix] Fix batch_norm (#14857)
[Bugfix][PyTorch] Fix batch_norm
---
python/tvm/relay/frontend/pytorch.py | 16 +++++++---------
tests/python/frontend/pytorch/test_forward.py | 17 +++++++++++++++++
2 files changed, 24 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1ef8b6faee..e479dd097d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1353,18 +1353,16 @@ class PyTorchOpConverter:
channels = self.infer_shape(data)
- if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr):
- scale = center = True
- weight = inputs[1]
- beta = inputs[2]
- gamma = weight
+ scale = isinstance(inputs[1], _expr.Expr)
+ if scale:
+ gamma = inputs[1]
else:
- scale = center = False
-
- if not scale:
gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
- if not center:
+ center = isinstance(inputs[2], _expr.Expr)
+ if center:
+ beta = inputs[2]
+ else:
beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
moving_mean = inputs[3]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index ffa37af331..7552f5cc61 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1383,9 +1383,26 @@ def test_forward_batchnorm():
inp_2d = torch.rand((1, 16, 10, 10))
inp_3d = torch.rand((1, 16, 10, 10, 10))
+ class BatchNorm(Module):
+ def __init__(self, weight, bias):
+ super().__init__()
+ self.weight = weight
+ self.bias = bias
+
+ def forward(self, *args):
+ return torch.nn.functional.batch_norm(
+ args[0],
+ running_mean=torch.zeros(args[0].shape[1]),
+ running_var=torch.ones(args[0].shape[1]),
+ weight=self.weight,
+ bias=self.bias,
+ )
+
for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d), (torch.nn.BatchNorm3d(16), inp_3d)]:
init_weight(bn.eval())
verify_model(bn.eval(), input_data=inp)
+ verify_model(BatchNorm(bn.weight, None).eval(), input_data=inp)
+ verify_model(BatchNorm(bn.weight, bn.bias).eval(), input_data=inp)
@tvm.testing.uses_gpu