This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b1364eb  [PYTORCH]Take, Topk op support (#5332)
b1364eb is described below

commit b1364ebbedb6bf540d1d2610d772ac441e2f7cb5
Author: Samuel <siju.sam...@huawei.com>
AuthorDate: Wed Apr 15 15:48:03 2020 +0530

    [PYTORCH]Take, Topk op support (#5332)
    
    * [PYTORCH]take, topk op support
    
    * Ci Failure fix
---
 python/tvm/relay/frontend/pytorch.py          | 35 ++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 57 +++++++++++++++++++++++++++
 2 files changed, 92 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 38a811d..0acebe4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -272,6 +272,39 @@ def _select():
         return _op.transform.take(data, index, axis=dim)
     return _impl
 
+def _take():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        import torch
+
+        if isinstance(inputs[1], _expr.Var):
+            indices = _op.cast(inputs[1], "int32")
+        elif isinstance(inputs[1], torch.Tensor):
+            indices = _wrap_const(inputs[1].numpy())
+        else:
+            msg = "Data type %s could not be parsed in take operator." % 
(type(inputs[1]))
+            raise AssertionError(msg)
+
+        return _op.transform.take(data, indices=indices)
+    return _impl
+
+def _topk():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        k = int(inputs[1])
+        axis = int(inputs[2])
+        is_ascend = not bool(inputs[3])
+        sort = bool(inputs[4])
+
+        if not sort:
+            msg = "Currently supports only sorted output for topk operator."
+            raise AssertionError(msg)
+
+        outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, 
ret_type="both")
+
+        return outs[0], outs[1]
+    return _impl
+
 def _reciprocal():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -1416,6 +1449,8 @@ def _get_convert_map(prelude):
         "aten::split"                           : _split(),
         "aten::split_with_sizes"                : _split_with_sizes(),
         "aten::select"                          : _select(),
+        "aten::take"                            : _take(),
+        "aten::topk"                            : _topk(),
         "aten::relu"                            : _relu(),
         "aten::relu_"                           : _relu(),
         "aten::prelu"                           : _prelu(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d9d280f..c562fce 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1545,6 +1545,61 @@ def test_forward_round():
     verify_model(Round1().float().eval(), input_data=input_data)
 
 
+def test_forward_take():
+    torch.set_grad_enabled(False)
+    class Take1(Module):
+        def forward(self, *args):
+            indices = torch.tensor([[0,0],[1,0]])
+            if torch.cuda.is_available():
+                indices = indices.cuda()
+            return torch.take(args[0], indices)
+
+    class Take2(Module):
+        def forward(self, *args):
+            return torch.take(args[0], args[1])
+
+    input_data = torch.tensor([[1,2],[3,4]])
+    verify_model(Take1().float().eval(), input_data=input_data)
+    indices = torch.tensor([[0,0],[1,0]])
+    verify_model(Take2().float().eval(), input_data=[input_data, indices])
+
+
+def test_forward_topk():
+    torch.set_grad_enabled(False)
+    class Topk1(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3)
+
+    class Topk2(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3, dim=-2)
+
+    class Topk3(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3, dim=3)
+
+    class Topk4(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3, largest=True)
+
+    class Topk5(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3, largest=False)
+
+    class Topk6(Module):
+        def forward(self, *args):
+            return torch.topk(args[0], k=3, sorted=True)
+
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+    verify_model(Topk1().float().eval(), input_data=input_data)
+    verify_model(Topk2().float().eval(), input_data=input_data)
+    verify_model(Topk3().float().eval(), input_data=input_data)
+    verify_model(Topk4().float().eval(), input_data=input_data)
+    verify_model(Topk5().float().eval(), input_data=input_data)
+    verify_model(Topk6().float().eval(), input_data=input_data)
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1587,6 +1642,8 @@ if __name__ == "__main__":
     test_forward_size()
     test_forward_view()
     test_forward_select()
+    test_forward_take()
+    test_forward_topk()
     test_forward_clone()
     test_forward_softplus()
     test_forward_softsign()

Reply via email to