This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9530a8de1e [PyTorch]  Add aten::new_zeros (#14747)
9530a8de1e is described below

commit 9530a8de1e2e6d150748ee6a7b45d39195785f50
Author: liquanfeng <[email protected]>
AuthorDate: Tue May 2 04:59:57 2023 +0800

    [PyTorch]  Add aten::new_zeros (#14747)
---
 python/tvm/relay/frontend/pytorch.py          | 17 +++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  7 +++++++
 2 files changed, 24 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index c81e882526..9f561df7bc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -832,6 +832,22 @@ class PyTorchOpConverter:
 
         return out
 
+    def new_zeros(self, inputs, input_types):
+        data = inputs[1]
+
+        import torch
+
+        if not isinstance(data, (_expr.Expr, list, tuple, torch.Size)):
+            msg = "Data type %s could not be parsed in new_zeros op" % (type(data))
+            raise AssertionError(msg)
+
+        if inputs[2] is not None:
+            dtype = _convert_dtype_value(inputs[2])
+        else:
+            # if dtype is None, use the dtype of the input tensor
+            dtype = self.infer_type(inputs[0])
+        return self.full_impl(data, 0, dtype)
+
     def full(self, inputs, input_types):
         data = inputs[0]
         fill_value = inputs[1]
@@ -3755,6 +3771,7 @@ class PyTorchOpConverter:
             "aten::zeros": self.zeros,
             "aten::zero_": self.zero_,
             "aten::zeros_like": self.zeros_like,
+            "aten::new_zeros": self.new_zeros,
             "aten::new_ones": self.new_ones,
             "aten::full": self.full,
             "aten::full_like": self.full_like,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index b5fcaaecae..897ebdec44 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3348,6 +3348,13 @@ def test_forward_zeros_like():
     verify_model(ZerosLike3().float().eval(), input_data=input_data)
 
 
+def test_forward_new_zeros():
+    def test_func(x):
+        return x.new_zeros((2, 3))
+
+    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
+
+
 @tvm.testing.uses_gpu
 def test_forward_full():
     """test_forward_full"""

Reply via email to