This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new ba38222 [PYTORCH]where, addcdiv, addcmul op support (#5383)
ba38222 is described below
commit ba38222990feb7c2dbb18bd2e23ae7551d440fd3
Author: Samuel <[email protected]>
AuthorDate: Fri Apr 24 16:19:26 2020 +0530
[PYTORCH]where, addcdiv, addcmul op support (#5383)
* [PYTORCH]Where, addcdiv, addcmul op support
* Review comments fixed
---
python/tvm/relay/frontend/pytorch.py | 72 +++++++++++++++------------
tests/python/frontend/pytorch/test_forward.py | 69 +++++++++++++++++++++++++
2 files changed, 110 insertions(+), 31 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 0ade8af..a8eb9c4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -279,15 +279,7 @@ def _select():
def _take():
def _impl(inputs, input_types):
data = inputs[0]
- import torch
-
- if isinstance(inputs[1], _expr.Var):
- indices = _op.cast(inputs[1], "int32")
- elif isinstance(inputs[1], torch.Tensor):
- indices = _wrap_const(inputs[1].numpy())
- else:
- msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
- raise AssertionError(msg)
+ indices = _op.cast(inputs[1], "int32")  # tensor constants now arrive pre-wrapped by _get_constant, so no torch.Tensor branch is needed
return _op.transform.take(data, indices=indices)
return _impl
@@ -337,6 +329,40 @@ def _repeat_interleave():
return _op.transform.repeat(data, repeats=repeats, axis=axis)
return _impl
+
+def _addcdiv():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ c = _expr.const(float(inputs[3]))  # value (default 1) may be a Python int; float() makes the const float32 instead of int32
+ t1 = inputs[1]
+ t2 = inputs[2]
+
+ return data + (c * (t1 / t2))
+ return _impl
+
+
+def _addcmul():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ c = _expr.const(float(inputs[3]))  # value (default 1) may be a Python int; float() makes the const float32 instead of int32
+ t1 = inputs[1]
+ t2 = inputs[2]
+
+ return data + (c * (t1 * t2))
+ return _impl
+
+
+def _where():
+ def _impl(inputs, input_types):
+ if len(inputs) != 3:  # only the where(cond, x, y) form is converted
+ raise NotImplementedError("aten::where is only supported with (condition, x, y) inputs")
+ cond, x, y = inputs
+
+ return _op.where(cond, x, y)
+
+ return _impl
+
+
def _ones():
def _impl(inputs, input_types):
data = inputs[0]
@@ -1382,16 +1408,7 @@ def _bitwise_not():
def _bitwise_xor():
def _impl(inputs, input_types):
lhs = inputs[0]
-
- import torch
- if isinstance(inputs[1], _expr.Var):
- rhs = inputs[1]
- elif isinstance(inputs[1], torch.Tensor):
- rhs = _wrap_const(inputs[1].numpy())
- else:
- msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
- raise AssertionError(msg)
-
+ rhs = inputs[1]  # tensor constants now arrive pre-wrapped by _get_constant
lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
@@ -1410,17 +1427,7 @@ def _logical_not():
def _logical_xor():
def _impl(inputs, input_types):
lhs = _op.cast(inputs[0], "bool")
-
- import torch
- if isinstance(inputs[1], _expr.Var):
- rhs = inputs[1]
- elif isinstance(inputs[1], torch.Tensor):
- rhs = _wrap_const(inputs[1].numpy())
- else:
- msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
- raise AssertionError(msg)
-
- rhs = _op.cast(rhs, "bool")
+ rhs = _op.cast(inputs[1], "bool")  # tensor constants now arrive pre-wrapped by _get_constant
return _op.logical_xor(lhs, rhs)
return _impl
@@ -1551,6 +1558,8 @@ def _get_convert_map(prelude):
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
+ "aten::addcdiv" : _addcdiv(),  # lowered as input + value * (tensor1 / tensor2)
+ "aten::addcmul" : _addcmul(),  # lowered as input + value * (tensor1 * tensor2)
"aten::ones" : _ones(),
"aten::ones_like" : _ones_like(),
"aten::zeros" : _zeros(),
@@ -1570,6 +1579,7 @@ def _get_convert_map(prelude):
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::take" : _take(),
+ "aten::where" : _where(),  # lowered to relay where(cond, x, y)
"aten::topk" : _topk(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
@@ -1832,7 +1842,7 @@ def _get_constant(node):
tensor = node.t(attr_name)
if len(tensor.shape) == 0: # tensor(0.1)
return float(tensor)
- return tensor
+ return _wrap_const(tensor.numpy())  # non-scalar tensor constants become Relay constants here, so op converters no longer special-case torch.Tensor
elif ty == "DeviceObjType":
return node.s(attr_name)
elif ty == "FunctionType":
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 0a0e6bb..91c2661 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1888,6 +1888,72 @@ def test_forward_unary():
verify_model(Neg1().float().eval(), input_data=input_data)
+def test_forward_where():
+ torch.set_grad_enabled(False)
+
+ class Where1(Module):
+ def forward(self, *args):
+ y = torch.ones([3, 2])
+ if torch.cuda.is_available():
+ y = y.cuda()
+ return torch.where(args[0] > 0, args[0], y)
+
+ class Where2(Module):
+ def forward(self, *args):
+ return torch.where(args[0] > 0, args[0], args[1])
+
+ x = torch.rand([3, 2]).float()
+ verify_model(Where1().float().eval(), input_data=[x])
+ y = torch.rand([3, 2])
+ verify_model(Where2().float().eval(), input_data=[x, y])
+
+
+def test_forward_addcdiv():
+ torch.set_grad_enabled(False)
+
+ class Addcdiv1(Module):
+ def forward(self, *args):
+ t1 = torch.ones([3, 1])
+ t2 = torch.ones([1, 3])
+ if torch.cuda.is_available():
+ t1 = t1.cuda()
+ t2 = t2.cuda()
+ return torch.addcdiv(args[0], t1, t2, value=0.1)
+
+ class Addcdiv2(Module):
+ def forward(self, *args):
+ return torch.addcdiv(args[0], args[1], args[2])  # default value=1 is an integer scalar
+
+ input_data = torch.rand([1, 3]).float()
+ verify_model(Addcdiv1().float().eval(), input_data=input_data)
+ t1 = torch.rand([3, 1]).float()
+ t2 = torch.rand([1, 3]).float()
+ verify_model(Addcdiv2().float().eval(), input_data=[input_data, t1, t2])
+
+
+def test_forward_addcmul():
+ torch.set_grad_enabled(False)
+
+ class Addcmul1(Module):
+ def forward(self, *args):
+ t1 = torch.ones([3, 1])
+ t2 = torch.ones([1, 3])
+ if torch.cuda.is_available():
+ t1 = t1.cuda()
+ t2 = t2.cuda()
+ return torch.addcmul(args[0], t1, t2, value=0.1)
+
+ class Addcmul2(Module):
+ def forward(self, *args):
+ return torch.addcmul(args[0], args[1], args[2])  # default value=1 is an integer scalar
+
+ input_data = torch.rand([1, 3]).float()
+ verify_model(Addcmul1().float().eval(), input_data=input_data)
+ t1 = torch.rand([3, 1]).float()
+ t2 = torch.rand([1, 3]).float()
+ verify_model(Addcmul2().float().eval(), input_data=[input_data, t1, t2])
+
+
if __name__ == "__main__":
# Single operator tests
test_forward_add()
@@ -1933,6 +1999,9 @@ if __name__ == "__main__":
test_forward_select()
test_forward_take()
test_forward_topk()
+ test_forward_where()
+ test_forward_addcdiv()
+ test_forward_addcmul()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()