This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new ba38222 [PYTORCH]where, addcdiv, addcmul op support (#5383) ba38222 is described below commit ba38222990feb7c2dbb18bd2e23ae7551d440fd3 Author: Samuel <siju.sam...@huawei.com> AuthorDate: Fri Apr 24 16:19:26 2020 +0530 [PYTORCH]where, addcdiv, addcmul op support (#5383) * [PYTORCH]Where, addcdiv, addcmul op support * Review comments fixed --- python/tvm/relay/frontend/pytorch.py | 72 +++++++++++++++------------ tests/python/frontend/pytorch/test_forward.py | 69 +++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 31 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0ade8af..a8eb9c4 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -279,15 +279,7 @@ def _select(): def _take(): def _impl(inputs, input_types): data = inputs[0] - import torch - - if isinstance(inputs[1], _expr.Var): - indices = _op.cast(inputs[1], "int32") - elif isinstance(inputs[1], torch.Tensor): - indices = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in take operator." % (type(inputs[1])) - raise AssertionError(msg) + indices = _op.cast(inputs[1], "int32") return _op.transform.take(data, indices=indices) return _impl @@ -337,6 +329,40 @@ def _repeat_interleave(): return _op.transform.repeat(data, repeats=repeats, axis=axis) return _impl + +def _addcdiv(): + def _impl(inputs, input_types): + data = inputs[0] + c = _expr.const(inputs[3]) + t1 = inputs[1] + t2 = inputs[2] + + return data + (c * (t1 / t2)) + return _impl + + +def _addcmul(): + def _impl(inputs, input_types): + data = inputs[0] + c = _expr.const(inputs[3]) + t1 = inputs[1] + t2 = inputs[2] + + return data + (c * (t1 * t2)) + return _impl + + +def _where(): + def _impl(inputs, input_types): + cond = inputs[0] + x = inputs[1] + y = inputs[2] + + return _op.where(cond, x, y) + + return _impl + + def _ones(): def _impl(inputs, input_types): data = inputs[0] @@ -1382,16 +1408,7 @@ def _bitwise_not(): def _bitwise_xor(): def _impl(inputs, input_types): lhs = inputs[0] - - import torch - if isinstance(inputs[1], _expr.Var): - rhs = inputs[1] - elif isinstance(inputs[1], torch.Tensor): - rhs = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1])) - raise AssertionError(msg) - + rhs = inputs[1] lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int") rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int") @@ -1410,17 +1427,7 @@ def _logical_not(): def _logical_xor(): def _impl(inputs, input_types): lhs = _op.cast(inputs[0], "bool") - - import torch - if isinstance(inputs[1], _expr.Var): - rhs = inputs[1] - elif isinstance(inputs[1], torch.Tensor): - rhs = _wrap_const(inputs[1].numpy()) - else: - msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1])) - raise AssertionError(msg) - - rhs = _op.cast(rhs, "bool") + rhs = _op.cast(inputs[1], "bool") return _op.logical_xor(lhs, rhs) return _impl @@ -1551,6 +1558,8 @@ def _get_convert_map(prelude): "aten::arange" : _arange(), "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), + "aten::addcdiv" : _addcdiv(), + "aten::addcmul" : _addcmul(), "aten::ones" : _ones(), "aten::ones_like" : _ones_like(), "aten::zeros" : _zeros(), @@ -1570,6 +1579,7 @@ def _get_convert_map(prelude): "aten::split_with_sizes" : _split_with_sizes(), "aten::select" : _select(), "aten::take" : _take(), + "aten::where" : _where(), "aten::topk" : _topk(), "aten::relu" : _relu(), "aten::relu_" : _relu(), @@ -1832,7 +1842,7 @@ def _get_constant(node): tensor = node.t(attr_name) if len(tensor.shape) == 0: # tensor(0.1) return float(tensor) - return tensor + return _wrap_const(tensor.numpy()) elif ty == "DeviceObjType": return node.s(attr_name) elif ty == "FunctionType": diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 0a0e6bb..91c2661 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1888,6 +1888,72 @@ def test_forward_unary(): verify_model(Neg1().float().eval(), input_data=input_data) +def test_forward_where(): + torch.set_grad_enabled(False) + + class Where1(Module): + def forward(self, *args): + y = torch.ones([3, 2]) + if torch.cuda.is_available(): + y = y.cuda() + return torch.where(args[0] > 0, args[0], y) + + class Where2(Module): + def forward(self, *args): + return torch.where(args[0] > 0, args[0], args[1]) + + x = torch.rand([3, 2]).float() + verify_model(Where1().float().eval(), input_data=[x]) + y = torch.rand([3, 2]) + verify_model(Where2().float().eval(), input_data=[x, y]) + + +def test_forward_addcdiv(): + torch.set_grad_enabled(False) + + class Addcdiv1(Module): + def forward(self, *args): + t1 = torch.ones([3, 1]) + t2 = torch.ones([1, 3]) + if torch.cuda.is_available(): + t1 = t1.cuda() + t2 = t2.cuda() + return torch.addcdiv(args[0], 0.1, t1, t2) + + class Addcdiv2(Module): + def forward(self, *args): + return torch.addcdiv(args[0], 0.5, args[1], args[2]) + + input_data = torch.rand([1, 3]).float() + verify_model(Addcdiv1().float().eval(), input_data=input_data) + t1 = torch.rand([3, 1]).float() + t2 = torch.rand([1, 3]).float() + verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2]) + + +def test_forward_addcmul(): + torch.set_grad_enabled(False) + + class Addcmul1(Module): + def forward(self, *args): + t1 = torch.ones([3, 1]) + t2 = torch.ones([1, 3]) + if torch.cuda.is_available(): + t1 = t1.cuda() + t2 = t2.cuda() + return torch.addcmul(args[0], 0.1, t1, t2) + + class Addcmul2(Module): + def forward(self, *args): + return torch.addcmul(args[0], 0.5, args[1], args[2]) + + input_data = torch.rand([1, 3]).float() + verify_model(Addcmul1().float().eval(), input_data=input_data) + t1 = torch.rand([3, 1]).float() + t2 = torch.rand([1, 3]).float() + verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2]) + + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -1933,6 +1999,9 @@ if __name__ == "__main__": test_forward_select() test_forward_take() test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() test_forward_clone() test_forward_softplus() test_forward_softsign()