This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 8eb6584  [PYTORCH]expand bug fix (#5576)
8eb6584 is described below

commit 8eb65848988b66db06aaed42dd0667a95308a686
Author: Samuel <[email protected]>
AuthorDate: Wed May 13 05:39:41 2020 +0530

    [PYTORCH]expand bug fix (#5576)
---
 python/tvm/relay/frontend/pytorch.py          | 17 +++++++++++++----
 tests/python/frontend/pytorch/test_forward.py | 11 ++++++++++-
 2 files changed, 23 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3af1051..d95a912 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1245,27 +1245,36 @@ def _matmul():
         return _op.nn.dense(data0, data1_t)
     return _impl
 
+
 def _expand():
     def _impl(inputs, input_types):
         data_in = inputs[0]
         if isinstance(data_in, _expr.Expr):
-            shape = _infer_shape(data_in)
+            shape = list(_infer_shape(data_in))
 
         ndims = len(shape)
         sizes = _infer_shape(inputs[1])
         out = inputs[0]
 
+        out_dims = len(sizes)
+        if ndims < out_dims:
+            num_newaxis = out_dims - ndims
+            out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
+            shape = [1] * num_newaxis + shape
+
         for i in range(ndims):
-            if sizes[i] in {-1, shape[i]}:
+            if sizes[i] == -1 or sizes[i] == shape[i]:
                 continue
             data = list()
             for temp in range(sizes[i]):
                 data.append(out)
-            call = _op.tensor.concatenate(data, i)
 
-        return call
+            out = _op.tensor.concatenate(data, i)
+
+        return out
     return _impl
 
+
 def _int():
     def _impl(inputs, input_types):
         if isinstance(inputs[0], _expr.Expr):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e1c276b..82a027f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -902,15 +902,24 @@ def test_forward_mean():
 
 def test_forward_expand():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
 
     class Expand1(Module):
         def forward(self, *args):
             return args[0].expand((3, -1, -1, -1))
 
+    input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
     verify_model(Expand1().float().eval(), input_data=input_data)
 
+    class Expand2(Module):
+        def forward(self, *args):
+            return args[0].expand((3, 3, 3, 1))
+
+    input_shape = [3, 1]
+    input_data = torch.rand(input_shape).float()
+    verify_model(Expand2().float().eval(), input_data=input_data)
+
+
 def test_forward_pow():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]

Reply via email to