This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 59f5cbe support aten::type_as in the pytorch frontend (#5787)
59f5cbe is described below
commit 59f5cbe921cf329febcd9d6eff2df94d80f1c523
Author: Rand Xie <[email protected]>
AuthorDate: Fri Jun 12 21:52:45 2020 -0700
support aten::type_as in the pytorch frontend (#5787)
* support aten::type_as in the pytorch frontend
* use _convert_data_type to convert the torch type to a tvm type and add
  more types to the type_as test
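
For context, aten::type_as casts its first tensor argument to the dtype of
its second. A minimal PyTorch sketch of the equivalence the converter relies
on (the tensor names here are illustrative, not from the patch):

    import torch

    a = torch.randn(1, 3)                     # float32 source tensor
    b = torch.zeros(1, 3, dtype=torch.int32)  # tensor carrying the target dtype
    # On the same device, type_as is a plain dtype cast, which is why
    # the frontend can lower it to a single relay cast operation.
    assert torch.equal(a.type_as(b), a.to(b.dtype))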
---
python/tvm/relay/frontend/pytorch.py | 9 +++++++
tests/python/frontend/pytorch/test_forward.py | 37 +++++++++++++++++++++++++++
2 files changed, 46 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index a9f4a7b..d2451cd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1645,6 +1645,14 @@ def _list_len(prelude):
    return _impl
+def _type_as():
+    def _impl(inputs, input_types):
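+        # type_as casts the first input to the dtype of the second input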
+        assert len(inputs) == 2
+        assert len(input_types) == 2
+        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
+    return _impl
+
+
def _add(prelude):
    # add_ is overloaded for tensor add and list concat
    def _impl(inputs, input_types):
@@ -1953,6 +1961,7 @@ def _get_convert_map(prelude):
"aten::stack" : _tensor_array_stack(prelude),
"aten::__getitem__" : _list_getitem(prelude),
"aten::len" : _list_len(prelude),
+ "aten::type_as" : _type_as(),
}
return convert_map
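
With the mapping registered, a traced module that calls type_as should import
through the frontend and hit the new converter. A rough usage sketch (the
input-name/shape spec passed to from_pytorch is an assumption and may differ
across TVM versions):

    import torch
    import tvm
    from tvm import relay

    class TypeAsModule(torch.nn.Module):
        def forward(self, x):
            # cast x to int32 via type_as, exercising aten::type_as
            return x.type_as(torch.zeros(1, 3, dtype=torch.int32))

    inp = torch.randn(1, 3)
    scripted = torch.jit.trace(TypeAsModule().eval(), inp)
    mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 3))])
    print(mod)  # the printed Relay module should contain a cast to int32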
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 86fb409..f8fb57f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -27,6 +27,7 @@ import torchvision
from tvm import relay
from tvm.contrib import graph_runtime
+from tvm.contrib.nvcc import have_fp16
from tvm.relay.testing.config import ctx_list
@@ -837,6 +838,41 @@ def test_forward_size():
    input_data = torch.rand(input_shape).float()
    verify_model(Size1().float().eval(), input_data=input_data)
+
+def test_type_as():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3]
+
+    def _create_module(dtype):
+        class TypeAs(Module):
+            def forward(self, *args):
+                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
+                return args[0].type_as(expected_type_tensor)
+
+        return TypeAs()
+
+    input_data = torch.randn(input_shape).float()
+    verify_model(_create_module(torch.float64), input_data=input_data)
+    verify_model(_create_module(torch.float32), input_data=input_data)
+    verify_model(_create_module(torch.int64), input_data=input_data)
+    verify_model(_create_module(torch.int32), input_data=input_data)
+    verify_model(_create_module(torch.int16), input_data=input_data)
+    verify_model(_create_module(torch.int8), input_data=input_data)
+
+    if torch.cuda.is_available():
+        check_fp16 = False
+        try:
+            # Only check half precision on supported hardware.
+            if have_fp16(tvm.gpu(0).compute_version):
+                check_fp16 = True
+        except Exception:
+            # If GPU is not enabled in TVM, skip the fp16 test.
+            pass
+
+        if check_fp16:
+            verify_model(_create_module(torch.float16), input_data=input_data)
+
+
def test_forward_view():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
@@ -2575,6 +2611,7 @@ if __name__ == "__main__":
    test_upsample()
    test_forward_upsample3d()
    test_to()
+    test_type_as()
    test_forward_functional_pad()
    test_forward_zero_pad2d()
    test_forward_constant_pad1d()