This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 430cb89  [Torch] Add support for split (#5174)
430cb89 is described below

commit 430cb89995bff298cca0adf6ef1087d071875d1a
Author: Wang Yucheng <[email protected]>
AuthorDate: Tue Mar 31 19:01:10 2020 +0800

    [Torch] Add support for split (#5174)
    
    * [Torch] Add support for split
    
    * fix
    
    * fix test class
---
 python/tvm/relay/frontend/pytorch.py          | 36 +++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 24 ++++++++++++++++++
 2 files changed, 60 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6a26711..7dee58e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -105,6 +105,36 @@ def _slice():
         return _op.transform.strided_slice(data, begin, end, strides)
     return _impl
 
+def _split():
+    def _impl(inputs, input_types):
+        """Convert aten::split(data, split_size, dim) to a relay split.
+
+        torch.split cuts the axis into chunks of ``split_size`` elements, with a
+        shorter final chunk when ``split_size`` does not divide the axis length.
+        relay's split takes the cut points instead, so those are computed here.
+        """
+        data = inputs[0]
+        split_size = int(inputs[1])
+        dim = int(inputs[2])
+
+        # Query the axis length once instead of calling _infer_shape in every
+        # iteration of the cut-point loop.
+        dim_size = int(_infer_shape(data)[dim])
+        # A zero or negative step cannot partition a non-empty axis; the
+        # previous accumulate-until-dim_size loop never terminated on them.
+        if split_size < 0 or (split_size == 0 and dim_size != 0):
+            raise ValueError("aten::split: invalid split_size {} for a dimension "
+                             "of size {}".format(split_size, dim_size))
+        # Cut points at every multiple of split_size strictly inside the axis.
+        indices = (list(range(split_size, dim_size, split_size))
+                   if split_size > 0 else [])
+
+        return _op.split(data, indices, dim)
+    return _impl
+
+def _split_with_sizes():
+    def _impl(inputs, inputs_types):
+        """Convert aten::split_with_sizes(data, split_sizes, dim) to a relay split.
+
+        relay's split takes cut points rather than chunk sizes, so the sizes are
+        turned into a running sum. The last size is implied by the axis length
+        and therefore contributes no cut point.
+        """
+        data = inputs[0]
+        dim = int(inputs[2])
+
+        # NOTE(review): the sizes are read via _infer_shape(inputs[1]), i.e. this
+        # assumes the sizes list arrives as an expression whose inferred shape
+        # equals the requested chunk sizes -- confirm against how the frontend
+        # lowers prim::ListConstruct.
+        sections = _infer_shape(inputs[1])
+
+        split_index = 0
+        indices = []
+        for size in list(sections)[:-1]:
+            split_index += size
+            indices.append(split_index)
+
+        return _op.split(data, indices, dim)
+    return _impl
+
 def _select():
     def _impl(inputs, input_types):
         data = inputs[0]
@@ -886,6 +916,8 @@ _convert_map = {
     "aten::unsqueeze"                       : _unsqueeze(),
     "aten::cat"                             : _concatenate(),
     "aten::slice"                           : _slice(),
+    "aten::split"                           : _split(),
+    "aten::split_with_sizes"                : _split_with_sizes(),
     "aten::select"                          : _select(),
     "aten::relu"                            : _relu(),
     "aten::relu_"                           : _relu(),
@@ -1415,6 +1447,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
 
     ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
                             output_index_map, ret_name)
+
+    if isinstance(ret[0], list):
+        ret[0] = _expr.Tuple(ret[0])
+
     func = tvm.relay.Function(_analysis.free_vars(ret[0]), ret[0])
 
     return _module.IRModule.from_expr(func), tvm_params
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 1878266..6070d88 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -379,6 +379,29 @@ def test_forward_maxpool1d():
                                     stride=2).eval(),
                 input_data)
 
+def test_forward_split():
+    torch.set_grad_enabled(False)
+    input_shape = [4, 10]
+
+    class Split(Module):
+        """Module wrapper so torch.split can be traced and compared."""
+        def __init__(self, split_size_or_sections, dim):
+            super(Split, self).__init__()
+            self.split_size_or_sections = split_size_or_sections
+            self.dim = dim
+
+        def forward(self, *args):
+            return torch.split(args[0], self.split_size_or_sections, self.dim)
+
+    input_data = torch.rand(input_shape).float()
+    # (split_size_or_sections, dim): even chunks, uneven trailing chunk,
+    # and an explicit list of sizes (exercises aten::split_with_sizes).
+    cases = [(2, 0), (3, 1), (4, 1), ([2, 3, 5], 1)]
+    for split_arg, split_dim in cases:
+        verify_model(Split(split_arg, split_dim).float().eval(),
+                     input_data=input_data)
+
 def test_forward_avgpool():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1077,6 +1100,7 @@ if __name__ == "__main__":
     test_forward_expand()
     test_forward_pow()
     test_forward_chunk()
+    test_forward_split()
     test_upsample()
     test_to()
     test_adaptive_pool3d()

Reply via email to