This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5f1421dd0f [Bugfix][PyTorch] Support use_input_stats in instance_norm (#14963)
5f1421dd0f is described below

commit 5f1421dd0f74485bb051ae298311081658195fb9
Author: Quanfeng Li <[email protected]>
AuthorDate: Sat May 27 18:52:17 2023 +0800

    [Bugfix][PyTorch] Support use_input_stats in instance_norm (#14963)
---
 python/tvm/relay/frontend/pytorch.py          | 17 +++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  2 ++
 2 files changed, 19 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7eb13f3546..5eada41753 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1386,6 +1386,9 @@ class PyTorchOpConverter:
         data = inputs[0]
         data_type = input_types[0]
         channels = self.infer_shape(data)
+        running_mean = inputs[3]
+        running_var = inputs[4]
+        use_input_stats = inputs[5]
 
         if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], 
_expr.Expr):
             scale = center = True
@@ -1402,6 +1405,20 @@ class PyTorchOpConverter:
             beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
 
         epsilon = float(inputs[7])
+
+        if not use_input_stats:
+            return _op.nn.batch_norm(
+                data,
+                gamma,
+                beta,
+                running_mean,
+                running_var,
+                axis=1,
+                epsilon=epsilon,
+                center=center,
+                scale=scale,
+            )[0]
+
         return _op.nn.instance_norm(
             data, gamma, beta, axis=1, epsilon=epsilon, center=center, 
scale=scale
         )
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6186b7909c..c8972828ab 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1427,6 +1427,8 @@ def test_forward_instancenorm():
     for ins_norm, inp in [
         (torch.nn.InstanceNorm2d(16), inp_2d),
         (torch.nn.InstanceNorm3d(16), inp_3d),
+        (torch.nn.InstanceNorm2d(16, track_running_stats=True), inp_2d),
+        (torch.nn.InstanceNorm3d(16, track_running_stats=True), inp_3d),
     ]:
         verify_model(ins_norm.eval(), input_data=inp)
 

Reply via email to