This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9530a8de1e [PyTorch] Add aten::new_zeros (#14747)
9530a8de1e is described below
commit 9530a8de1e2e6d150748ee6a7b45d39195785f50
Author: liquanfeng <[email protected]>
AuthorDate: Tue May 2 04:59:57 2023 +0800
[PyTorch] Add aten::new_zeros (#14747)
---
python/tvm/relay/frontend/pytorch.py | 17 +++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 7 +++++++
2 files changed, 24 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c81e882526..9f561df7bc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -832,6 +832,22 @@ class PyTorchOpConverter:
return out
+    def new_zeros(self, inputs, input_types):
+        data = inputs[1]
+
+        import torch
+
+        if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
+            msg = "Data type %s could not be parsed in new_zeros op" % (type(data))
+            raise AssertionError(msg)
+
+        if inputs[2] is not None:
+            dtype = _convert_dtype_value(inputs[2])
+        else:
+            # dtype omitted: reuse the input's dtype (infer_type gives a type, not a dtype)
+            dtype = self.infer_type(inputs[0]).dtype
+        return self.full_impl(data, 0, dtype)
+
def full(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
@@ -3755,6 +3771,7 @@ class PyTorchOpConverter:
"aten::zeros": self.zeros,
"aten::zero_": self.zero_,
"aten::zeros_like": self.zeros_like,
+ "aten::new_zeros": self.new_zeros,
"aten::new_ones": self.new_ones,
"aten::full": self.full,
"aten::full_like": self.full_like,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b5fcaaecae..897ebdec44 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3348,6 +3348,13 @@ def test_forward_zeros_like():
verify_model(ZerosLike3().float().eval(), input_data=input_data)
+def test_forward_new_zeros():
+    def make_zeros(inp):
+        return inp.new_zeros((2, 3))
+
+    verify_model_with_input(make_zeros, [torch.rand(1, 3, 10, 10).float()])
+
+
@tvm.testing.uses_gpu
def test_forward_full():
"""test_forward_full"""