This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 22db299 [PYTORCH]Unary Ops (#5378)
22db299 is described below
commit 22db299b33f05570db2a5a406bdb37b57198a822
Author: Samuel <[email protected]>
AuthorDate: Mon Apr 20 07:18:51 2020 +0530
[PYTORCH]Unary Ops (#5378)
---
python/tvm/relay/frontend/pytorch.py | 96 +++++-------------
tests/python/frontend/pytorch/test_forward.py | 141 ++++++++++++++++----------
2 files changed, 114 insertions(+), 123 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9da3ecf..0ade8af 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -132,12 +132,16 @@ def _elemwise(name):
return get_relay_op(name)(data0, data1)
return _impl
-def _abs():
+
+def _unary(name):
def _impl(inputs, input_types):
- data = inputs[0]
- return _op.abs(data)
+ input_type = input_types[0]
+ data = _convert_elemwise_input(inputs[0], input_type)
+
+ return get_relay_op(name)(data)
return _impl
+
def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
@@ -1254,26 +1258,6 @@ def _pad():
return _op.nn.pad(data, pad_width, pad_value)
return _impl
-def _sqrt():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.tensor.sqrt(data)
- return _impl
-
-
-def _rsqrt():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.tensor.rsqrt(data)
- return _impl
-
-
-def _ceil():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.ceil(data)
- return _impl
-
def _clamp():
def _impl(inputs, input_types):
@@ -1284,20 +1268,6 @@ def _clamp():
return _impl
-def _floor():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.floor(data)
- return _impl
-
-
-def _round():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.round(data)
- return _impl
-
-
def _to():
def _impl(inputs, input_types):
data = inputs[0]
@@ -1375,17 +1345,6 @@ def _expand_as():
return inputs[0]
return _impl
-def _neg():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.tensor.negative(data)
- return _impl
-
-def _tanh():
- def _impl(inputs, input_types):
- data = inputs[0]
- return _op.tensor.tanh(data)
- return _impl
def _Bool():
def _impl(inputs, input_types):
@@ -1467,18 +1426,6 @@ def _logical_xor():
return _impl
-def _isfinite():
- def _impl(inputs, input_types):
- return _op.isfinite(inputs[0])
- return _impl
-
-
-def _isnan():
- def _impl(inputs, input_types):
- return _op.isnan(inputs[0])
- return _impl
-
-
def _list_getitem(prelude):
def _impl(inputs, input_types):
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1601,7 +1548,6 @@ def _get_convert_map(prelude):
"aten::mul" : _elemwise("multiply"),
"aten::mul_" : _elemwise("multiply"),
"aten::pow" : _elemwise("power"),
- "aten::abs" : _abs(),
"aten::arange" : _arange(),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
@@ -1683,12 +1629,26 @@ def _get_convert_map(prelude):
"aten::argmax" : _reduce("argmax"),
"aten::std" : _std(),
"aten::var" : _variance(),
- "aten::sqrt" : _sqrt(),
- "aten::rsqrt" : _rsqrt(),
- "aten::ceil" : _ceil(),
+ "aten::abs" : _unary("abs"),
+ "aten::neg" : _unary("negative"),
+ "aten::cos" : _unary("cos"),
+ "aten::sin" : _unary("sin"),
+ "aten::tan" : _unary("tan"),
+ "aten::tanh" : _unary("tanh"),
+ "aten::atan" : _unary("atan"),
+ "aten::log" : _unary("log"),
+ "aten::exp" : _unary("exp"),
+ "aten::erf" : _unary("erf"),
+ "aten::trunc" : _unary("trunc"),
+ "aten::sign" : _unary("sign"),
+ "aten::sqrt" : _unary("sqrt"),
+ "aten::rsqrt" : _unary("rsqrt"),
+ "aten::ceil" : _unary("ceil"),
+ "aten::floor" : _unary("floor"),
+ "aten::round" : _unary("round"),
+ "aten::isfinite" : _unary("isfinite"),
+ "aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
- "aten::floor" : _floor(),
- "aten::round" : _round(),
"aten::detach" : _identity(),
"aten::upsample_bilinear2d" : _upsample("bilinear"),
"aten::upsample_nearest2d" :
_upsample("nearest_neighbor"),
@@ -1703,12 +1663,8 @@ def _get_convert_map(prelude):
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
- "aten::isfinite" : _isfinite(),
- "aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
"aten::Float" : _Float(),
- "aten::neg" : _neg(),
- "aten::tanh" : _tanh(),
"aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(),
"aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(),
"aten::mm" : _matmul(),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index c692c5e..0a0e6bb 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1497,30 +1497,6 @@ def test_forward_isinf():
verify_model(IsInf1().float().eval(), input_data=input_data)
-def test_forward_rsqrt():
- torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
-
- class Rsqrt1(Module):
- def forward(self, *args):
- return torch.rsqrt(args[0])
-
- input_data = torch.rand(input_shape).float()
- verify_model(Rsqrt1().float().eval(), input_data=input_data)
-
-
-def test_forward_ceil():
- torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
-
- class Ceil1(Module):
- def forward(self, *args):
- return torch.ceil(args[0])
-
- input_data = torch.rand(input_shape).float()
- verify_model(Ceil1().float().eval(), input_data=input_data)
-
-
def test_forward_clamp():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
@@ -1543,30 +1519,6 @@ def test_forward_clamp():
verify_model(Clamp3().float().eval(), input_data=input_data)
-def test_forward_floor():
- torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
-
- class Floor1(Module):
- def forward(self, *args):
- return torch.floor(args[0])
-
- input_data = torch.rand(input_shape).float()
- verify_model(Floor1().float().eval(), input_data=input_data)
-
-
-def test_forward_round():
- torch.set_grad_enabled(False)
- input_shape = [1, 3, 10, 10]
-
- class Round1(Module):
- def forward(self, *args):
- return torch.round(args[0])
-
- input_data = torch.rand(input_shape).float()
- verify_model(Round1().float().eval(), input_data=input_data)
-
-
def test_forward_ones():
torch.set_grad_enabled(False)
@@ -1849,6 +1801,93 @@ def test_forward_logical_xor():
verify_model(LogicalXor2().float().eval(), input_data=[lhs])
+def test_forward_unary():
+ torch.set_grad_enabled(False)
+
+ class Sqrt1(Module):
+ def forward(self, *args):
+ return torch.sqrt(args[0])
+
+ class RSqrt1(Module):
+ def forward(self, *args):
+ return torch.rsqrt(args[0])
+
+ class Ceil1(Module):
+ def forward(self, *args):
+ return torch.ceil(args[0])
+
+ class Floor1(Module):
+ def forward(self, *args):
+ return torch.floor(args[0])
+
+ class Round1(Module):
+ def forward(self, *args):
+ return torch.round(args[0])
+
+ class Cos1(Module):
+ def forward(self, *args):
+ return torch.cos(args[0])
+
+ class Sin1(Module):
+ def forward(self, *args):
+ return torch.sin(args[0])
+
+ class Tan1(Module):
+ def forward(self, *args):
+ return torch.tan(args[0])
+
+ class Tanh1(Module):
+ def forward(self, *args):
+ return torch.tanh(args[0])
+
+ class ATan1(Module):
+ def forward(self, *args):
+ return torch.atan(args[0])
+
+ class Log1(Module):
+ def forward(self, *args):
+ return torch.log(args[0])
+
+ class Exp1(Module):
+ def forward(self, *args):
+ return torch.exp(args[0])
+
+ class Erf1(Module):
+ def forward(self, *args):
+ return torch.erf(args[0])
+
+ class Trunc1(Module):
+ def forward(self, *args):
+ return torch.trunc(args[0])
+
+ class Sign1(Module):
+ def forward(self, *args):
+ return torch.sign(args[0])
+
+ class Neg1(Module):
+ def forward(self, *args):
+ return torch.neg(args[0])
+
+ input_shape = [1, 3, 10, 10]
+ input_data = torch.rand(input_shape).float()
+ verify_model(Sqrt1().float().eval(), input_data=input_data)
+ verify_model(RSqrt1().float().eval(), input_data=input_data)
+ verify_model(Ceil1().float().eval(), input_data=input_data)
+ verify_model(Floor1().float().eval(), input_data=input_data)
+ verify_model(Round1().float().eval(), input_data=input_data)
+ verify_model(Cos1().float().eval(), input_data=input_data)
+ verify_model(Sin1().float().eval(), input_data=input_data)
+ verify_model(Tan1().float().eval(), input_data=input_data)
+ verify_model(Tanh1().float().eval(), input_data=input_data)
+ verify_model(ATan1().float().eval(), input_data=input_data)
+ verify_model(Log1().float().eval(), input_data=input_data)
+ verify_model(Exp1().float().eval(), input_data=input_data)
+ verify_model(Erf1().float().eval(), input_data=input_data)
+ verify_model(Trunc1().float().eval(), input_data=input_data)
+ verify_model(Sign1().float().eval(), input_data=input_data)
+ verify_model(Neg1().float().eval(), input_data=input_data)
+
+
if __name__ == "__main__":
# Single operator tests
test_forward_add()
@@ -1907,12 +1946,8 @@ if __name__ == "__main__":
test_forward_mean()
test_forward_expand()
test_forward_pow()
- test_forward_abs()
- test_forward_rsqrt()
- test_forward_ceil()
+ test_forward_unary()
test_forward_clamp()
- test_forward_floor()
- test_forward_round()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()