siju-samuel commented on a change in pull request #5332: [PYTORCH]Take, Topk op support
URL: https://github.com/apache/incubator-tvm/pull/5332#discussion_r408729508
##########
File path: tests/python/frontend/pytorch/test_forward.py
##########
@@ -1545,6 +1545,61 @@ def forward(self, *args):
     verify_model(Round1().float().eval(), input_data=input_data)
+def test_forward_take():
+    torch.set_grad_enabled(False)
+    class Take1(Module):
+        def forward(self, *args):
+            indices = torch.tensor([[0,0],[1,0]])
+            if torch.cuda.is_available():
Review comment:
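It was not there in my first commit, but then the [CI failed](https://ci.tvm.ai/blue/organizations/jenkins/tvm/detail/PR-5332/1/pipeline). I think TVM runs only on CPU, while the torch baseline runs on GPU: judging from the traceback, the input reaches `forward` on `cuda:0`, but the index tensor created inside `forward` stays on the CPU.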
```
    def test_forward_take():
        torch.set_grad_enabled(False)
        class Take1(Module):
            def forward(self, *args):
                return torch.take(args[0], torch.tensor([[0,0],[1,0]]))
        class Take2(Module):
            def forward(self, *args):
                return torch.take(args[0], args[1])
        input_data = torch.tensor([[1,2],[3,4]])
>       verify_model(Take1().float().eval(), input_data=input_data)
tests/python/frontend/pytorch/test_forward.py:1559:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/frontend/pytorch/test_forward.py:157: in verify_model
    baseline_outputs = baseline_model(*baseline_input)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py:532: in __call__
    result = self.forward(*input, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = Take1(), args = (tensor([[1, 2],
        [3, 4]], device='cuda:0'),)
    def forward(self, *args):
>       return torch.take(args[0], torch.tensor([[0,0],[1,0]]))
E       RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'index' in call to _th_take
```
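A possible alternative to branching on `torch.cuda.is_available()`, sketched here under the assumption that the index only needs to live on the same device as the input: create the index with `device=args[0].device`. This is an illustration, not the code merged in the PR. It also shows the `torch.take` semantics the test relies on: the input is indexed as if flattened to 1-D, and the result takes the shape of the index.

```python
import torch
from torch.nn import Module


class Take1(Module):
    """Device-agnostic variant of the Take1 test module (illustrative sketch)."""

    def forward(self, *args):
        # Build the index on the input's device, so a CUDA input is never
        # paired with a CPU index (the mismatch reported in the CI log).
        indices = torch.tensor([[0, 0], [1, 0]], device=args[0].device)
        # torch.take treats args[0] as flattened: [[1, 2], [3, 4]] -> [1, 2, 3, 4].
        return torch.take(args[0], indices)


if __name__ == "__main__":
    input_data = torch.tensor([[1, 2], [3, 4]])
    if torch.cuda.is_available():
        input_data = input_data.cuda()
    # The index follows the input, so this works on both CPU and GPU.
    print(Take1().float().eval()(input_data))
```

On CPU this prints `tensor([[1, 1], [2, 1]])`; on a GPU machine the values are the same, with `device='cuda:0'` on the result.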
----------------------------------------------------------------
This is an automated message from the Apache Git Service.