This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5f1421dd0f [Bugfix][PyTorch] Support use_input_stats in instance_norm (#14963)
5f1421dd0f is described below
commit 5f1421dd0f74485bb051ae298311081658195fb9
Author: Quanfeng Li <[email protected]>
AuthorDate: Sat May 27 18:52:17 2023 +0800
[Bugfix][PyTorch] Support use_input_stats in instance_norm (#14963)
---
python/tvm/relay/frontend/pytorch.py | 17 +++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 2 ++
2 files changed, 19 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 7eb13f3546..5eada41753 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1386,6 +1386,9 @@ class PyTorchOpConverter:
data = inputs[0]
data_type = input_types[0]
channels = self.infer_shape(data)
+ running_mean = inputs[3]
+ running_var = inputs[4]
+ use_input_stats = inputs[5]
if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2],
_expr.Expr):
scale = center = True
@@ -1402,6 +1405,20 @@ class PyTorchOpConverter:
beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
epsilon = float(inputs[7])
+
+ if not use_input_stats:
+ return _op.nn.batch_norm(
+ data,
+ gamma,
+ beta,
+ running_mean,
+ running_var,
+ axis=1,
+ epsilon=epsilon,
+ center=center,
+ scale=scale,
+ )[0]
+
return _op.nn.instance_norm(
data, gamma, beta, axis=1, epsilon=epsilon, center=center,
scale=scale
)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6186b7909c..c8972828ab 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1427,6 +1427,8 @@ def test_forward_instancenorm():
for ins_norm, inp in [
(torch.nn.InstanceNorm2d(16), inp_2d),
(torch.nn.InstanceNorm3d(16), inp_3d),
+ (torch.nn.InstanceNorm2d(16, track_running_stats=True), inp_2d),
+ (torch.nn.InstanceNorm3d(16, track_running_stats=True), inp_3d),
]:
verify_model(ins_norm.eval(), input_data=inp)