This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 430cb89 [Torch] Add support for split (#5174)
430cb89 is described below
commit 430cb89995bff298cca0adf6ef1087d071875d1a
Author: Wang Yucheng <[email protected]>
AuthorDate: Tue Mar 31 19:01:10 2020 +0800
[Torch] Add support for split (#5174)
* [Torch] Add support for split
* fix
* fix test class
---
python/tvm/relay/frontend/pytorch.py | 36 +++++++++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++
2 files changed, 60 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6a26711..7dee58e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -105,6 +105,36 @@ def _slice():
return _op.transform.strided_slice(data, begin, end, strides)
return _impl
+def _split():
+    """Return the converter registered for ``aten::split``."""
+    def _impl(inputs, input_types):
+        """Split ``inputs[0]`` along axis ``inputs[2]`` in steps of ``inputs[1]``.
+
+        Parameters
+        ----------
+        inputs : list
+            ``inputs[0]`` is the data expression, ``inputs[1]`` the chunk size
+            and ``inputs[2]`` the axis; the latter two are converted with
+            ``int()``.
+        input_types : list
+            Unused by this converter.
+
+        Returns
+        -------
+        The result of ``_op.split(data, indices, dim)``.
+
+        Raises
+        ------
+        ValueError
+            If the chunk size is not positive while the axis is non-empty;
+            the cut-point loop below could otherwise never terminate.
+        """
+        data = inputs[0]
+        split_size = int(inputs[1])
+        dim = int(inputs[2])
+
+        # The extent along `dim` does not change inside the loop, so the shape
+        # is inferred once instead of on every iteration.
+        dim_size = _infer_shape(data)[dim]
+
+        # A zero or negative step starting below the extent would loop forever.
+        if split_size <= 0 and split_size < dim_size:
+            raise ValueError(
+                "split_size must be positive, got {}".format(split_size))
+
+        # Cut points are the multiples of split_size that are below dim_size.
+        split_index = split_size
+        indices = []
+        while split_index < dim_size:
+            indices.append(split_index)
+            split_index += split_size
+
+        return _op.split(data, indices, dim)
+    return _impl
+
+def _split_with_sizes():
+    """Return the converter registered for ``aten::split_with_sizes``."""
+    def _impl(inputs, inputs_types):
+        """Split ``inputs[0]`` along axis ``inputs[2]`` into explicit sections.
+
+        The section sizes are read via ``_infer_shape(inputs[1])``. Their
+        running sums, excluding the final section, are passed as the cut
+        points to ``_op.split``.
+        """
+        data = inputs[0]
+        dim = int(inputs[2])
+        sections = _infer_shape(inputs[1])
+
+        # The last section needs no cut point: it runs to the end of the axis.
+        indices = []
+        offset = 0
+        for section_size in sections[:-1]:
+            offset += section_size
+            indices.append(offset)
+
+        return _op.split(data, indices, dim)
+    return _impl
+
def _select():
def _impl(inputs, input_types):
data = inputs[0]
@@ -886,6 +916,8 @@ _convert_map = {
"aten::unsqueeze" : _unsqueeze(),
"aten::cat" : _concatenate(),
"aten::slice" : _slice(),
+ "aten::split" : _split(),
+ "aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
@@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
+
+ if isinstance(ret[0], list):
+ ret[0] = _expr.Tuple(ret[0])
+
func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
return _module.IRModule.from_expr(func), tvm_params
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1878266..6070d88 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -379,6 +379,29 @@ def test_forward_maxpool1d():
stride=2).eval(),
input_data)
+def test_forward_split():
+    """Run ``torch.split`` models through ``verify_model`` on a 4x10 input."""
+    torch.set_grad_enabled(False)
+    input_shape = [4, 10]
+
+    class Split(Module):
+        """Module whose forward pass is a single ``torch.split`` call."""
+        def __init__(self, split_size_or_sections, dim):
+            super(Split, self).__init__()
+            self.split_size_or_sections = split_size_or_sections
+            self.dim = dim
+
+        def forward(self, *args):
+            return torch.split(args[0], self.split_size_or_sections, self.dim)
+
+    input_data = torch.rand(input_shape).float()
+
+    # (split_size_or_sections, dim) pairs, checked in this order.
+    cases = [
+        (2, 0),
+        (3, 1),
+        (4, 1),
+        ([2, 3, 5], 1),
+    ]
+    for split_arg, axis in cases:
+        model = Split(split_arg, axis).float().eval()
+        verify_model(model, input_data=input_data)
+
+
def test_forward_avgpool():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
@@ -1077,6 +1100,7 @@ if __name__ == "__main__":
test_forward_expand()
test_forward_pow()
test_forward_chunk()
+ test_forward_split()
test_upsample()
test_to()
test_adaptive_pool3d()