This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cd7d64e914 Fix a bug caused by PyTorch instance_norm when the input
shape is [1,1,1,2] (#15683)
cd7d64e914 is described below
commit cd7d64e914f763f1f3abc9105142a3f60580c31b
Author: Haoyang <[email protected]>
AuthorDate: Fri Sep 8 15:28:00 2023 +0800
Fix a bug caused by PyTorch instance_norm when the input shape is [1,1,1,2]
(#15683)
* Fix an adaptive_max_pool1d operator conversion bug
* Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* add tests for Fix an adaptive_max_pool1d operator conversion bug
* Fix an adaptive_max_pool1d operator conversion bug
* Fix an adaptive_max_pool1d operator conversion bug
* Add a TODO
* Add a TODO
* Add a TODO
* Fix the bug caused by torch's instance_norm when the input shape is [1, 1, 1,
2]
* Add a unit test
* Fix the bug caused by torch's instance_norm when the input shape is [1, 1, 1,
2]
* Fix the bug caused by torch's instance_norm when the input shape is [1, 1, 1,
2]
* simplify the last fix
---
python/tvm/relay/frontend/pytorch.py | 2 +-
tests/python/frontend/pytorch/test_forward.py | 14 ++++++++++++++
2 files changed, 15 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/pytorch.py
b/python/tvm/relay/frontend/pytorch.py
index 683b94dd92..9ddd04b5b4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4424,7 +4424,7 @@ def _create_typed_const(data, dtype):
dtype should be a TVM dtype"""
if dtype == "float64":
- typed_data = _expr.const(np.float64(data), dtype=dtype)
+ typed_data = _expr.const(np.asarray(data, dtype="float64"),
dtype=dtype)
elif dtype == "float32":
typed_data = _expr.const(np.float32(data), dtype=dtype)
elif dtype == "float16":
diff --git a/tests/python/frontend/pytorch/test_forward.py
b/tests/python/frontend/pytorch/test_forward.py
index 8c1cdbb0cf..9ee03512e7 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3449,6 +3449,20 @@ def test_forward_adaptive_max_pool1d():
verify_model(m.float().eval(), input_data=input_data)
+@tvm.testing.uses_gpu
+def test_forward_instance_norm():
+ """test_forward_instance_norm"""
+
+ class instance_norm(Module):
+ def forward(self, *args):
+ return torch.nn.functional.instance_norm(args[0],
use_input_stats=True)
+
+ m = instance_norm().float().eval()
+ input_data = torch.randn([1, 1, 1, 2], dtype=torch.float64)
+
+ verify_model(m.float().eval(), input_data=input_data)
+
+
@tvm.testing.uses_gpu
def test_forward_full_like():
"""test_forward_full_like"""