This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b4475b83aa [Bugfix] Fix batch_norm (#14857)
b4475b83aa is described below

commit b4475b83aa3f51e7fdc852ba6bd596ff4c8cd4ff
Author: Quanfeng Li <[email protected]>
AuthorDate: Thu May 18 14:36:15 2023 +0800

    [Bugfix] Fix batch_norm (#14857)
    
    [Bugfix][PyTorch] Fix batch_norm
---
 python/tvm/relay/frontend/pytorch.py          | 16 +++++++---------
 tests/python/frontend/pytorch/test_forward.py | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1ef8b6faee..e479dd097d 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1353,18 +1353,16 @@ class PyTorchOpConverter:
 
         channels = self.infer_shape(data)
 
-        if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], 
_expr.Expr):
-            scale = center = True
-            weight = inputs[1]
-            beta = inputs[2]
-            gamma = weight
+        scale = isinstance(inputs[1], _expr.Expr)
+        if scale:
+            gamma = inputs[1]
         else:
-            scale = center = False
-
-        if not scale:
             gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
 
-        if not center:
+        center = isinstance(inputs[2], _expr.Expr)
+        if center:
+            beta = inputs[2]
+        else:
             beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
 
         moving_mean = inputs[3]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index ffa37af331..7552f5cc61 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1383,9 +1383,26 @@ def test_forward_batchnorm():
     inp_2d = torch.rand((1, 16, 10, 10))
     inp_3d = torch.rand((1, 16, 10, 10, 10))
 
+    class BatchNorm(Module):
+        def __init__(self, weight, bias):
+            super().__init__()
+            self.weight = weight
+            self.bias = bias
+
+        def forward(self, *args):
+            return torch.nn.functional.batch_norm(
+                args[0],
+                running_mean=torch.zeros(args[0].shape[1]),
+                running_var=torch.ones(args[0].shape[1]),
+                weight=self.weight,
+                bias=self.bias,
+            )
+
     for bn, inp in [(torch.nn.BatchNorm2d(16), inp_2d), 
(torch.nn.BatchNorm3d(16), inp_3d)]:
         init_weight(bn.eval())
         verify_model(bn.eval(), input_data=inp)
+        verify_model(BatchNorm(bn.weight, None).eval(), input_data=inp)
+        verify_model(BatchNorm(bn.weight, bn.bias).eval(), input_data=inp)
 
 
 @tvm.testing.uses_gpu

Reply via email to