This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 8eb6584 [PYTORCH]expand bug fix (#5576)
8eb6584 is described below
commit 8eb65848988b66db06aaed42dd0667a95308a686
Author: Samuel <[email protected]>
AuthorDate: Wed May 13 05:39:41 2020 +0530
[PYTORCH]expand bug fix (#5576)
---
python/tvm/relay/frontend/pytorch.py | 17 +++++++++++++----
tests/python/frontend/pytorch/test_forward.py | 11 ++++++++++-
2 files changed, 23 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3af1051..d95a912 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1245,27 +1245,36 @@ def _matmul():
return _op.nn.dense(data0, data1_t)
return _impl
+
def _expand():
    """Build the converter for PyTorch's ``expand``.

    Returns:
        A converter ``_impl(inputs, input_types)``. ``inputs[0]`` is the Relay
        expression to expand. The target sizes are obtained via
        ``_infer_shape(inputs[1])``; a size of ``-1`` keeps the corresponding
        dimension unchanged. ``input_types`` is unused.
    """
    def _impl(inputs, input_types):
        data_in = inputs[0]
        if not isinstance(data_in, _expr.Expr):
            # Without this guard ``shape`` would be unbound below.
            raise NotImplementedError(
                "expand: expected a Relay expression as input, got {}".format(type(data_in))
            )

        shape = list(_infer_shape(data_in))
        sizes = _infer_shape(inputs[1])
        out = data_in

        num_newaxis = len(sizes) - len(shape)
        if num_newaxis < 0:
            raise ValueError(
                "expand: target sizes {} have fewer dimensions than input shape {}".format(
                    list(sizes), shape
                )
            )
        if num_newaxis > 0:
            # The target may add new leading dimensions; materialize them as
            # size-1 axes so the input and target ranks line up.
            out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis)
            shape = [1] * num_newaxis + shape

        # Walk every OUTPUT axis. Iterating only over the input's original rank
        # would skip trailing axes of a rank-extended input that still need
        # broadcasting (e.g. [3, 1] -> (2, 3, 3, 4)).
        for axis, (dim, size) in enumerate(zip(shape, sizes)):
            # -1 keeps the dim, equal sizes need no work, and only singleton
            # dims can be broadcast; anything else is left untouched.
            if size == -1 or size == dim or dim != 1:
                continue
            # Concatenating `size` copies of a size-1 axis broadcasts it.
            out = _op.tensor.concatenate([out] * size, axis)

        return out
    return _impl
+
def _int():
def _impl(inputs, input_types):
if isinstance(inputs[0], _expr.Expr):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e1c276b..82a027f 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -902,15 +902,24 @@ def test_forward_mean():
def test_forward_expand():
    """Verify conversion of ``Tensor.expand`` against PyTorch's own output."""
    torch.set_grad_enabled(False)

    class Expand1(Module):
        def forward(self, *args):
            return args[0].expand((3, -1, -1, -1))

    class Expand2(Module):
        def forward(self, *args):
            return args[0].expand((3, 3, 3, 1))

    # (model class, input shape) pairs, checked in order.
    cases = [
        # Broadcast the singleton leading dim; -1 keeps the remaining dims.
        (Expand1, [1, 3, 10, 10]),
        # Target has more dimensions (4) than the input (2).
        (Expand2, [3, 1]),
    ]
    for model_cls, shape in cases:
        sample = torch.rand(shape).float()
        verify_model(model_cls().float().eval(), input_data=sample)
def test_forward_pow():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]