This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 88f6f79 [Frontend][Pytorch]Add Pytorch advanced indexing (#6318)
88f6f79 is described below
commit 88f6f790a138a17e33c53bb70031329de162fa67
Author: Yao Wang <[email protected]>
AuthorDate: Fri Aug 21 18:54:26 2020 -0700
[Frontend][Pytorch]Add Pytorch advanced indexing (#6318)
* Add Pytorch advanced indexing
* Minor fix for test
* Fix for cuda
---
python/tvm/relay/frontend/pytorch.py | 53 +++++++++++++++++++++++++--
tests/python/frontend/pytorch/test_forward.py | 24 +++++++++++-
2 files changed, 72 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b75f3f9..7237403 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -274,16 +274,18 @@ def _slice():
end[dim] = min(end[dim], int(inputs[3]))
else:
if isinstance(inputs[3], _expr.Call):
- end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
+ target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
- end[dim] = inputs[3]
+ target_end = inputs[3]
+
+ end[dim] = min(end[dim], target_end)
strides.append(int(inputs[4]))
return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
- slice_mode="size")
+ slice_mode="end")
return _impl
def _split():
@@ -1759,6 +1761,50 @@ def _one_hot():
return _impl
+def _index():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ indices = []
+ raw_indices = []
+ max_indices_len = -1
+ for index in inputs[1]:
+ if not isinstance(index, _expr.Constant):
+ try:
+ index = _expr.const(_infer_value(index, {}))
+ except Exception:
+ raise RuntimeError("Only supports constant indices for "
+ "pytorch advanced indexing ")
+ raw_indices.append(index)
+ cindex_len = index.data.shape[0]
+ if cindex_len > max_indices_len:
+ max_indices_len = cindex_len
+
+ for index in raw_indices:
+ cnp = index.data.asnumpy()
+ cindex_len = cnp.shape[0]
+ if cindex_len < max_indices_len:
+ cnp = np.tile(cnp, max_indices_len // cindex_len)
+ indices.append(cnp)
+
+ ret = []
+ slice_map = {}
+ for i in range(indices[0].shape[0]):
+ tmp = data
+ current_indices = []
+ for index in indices:
+ current_indices.append(index[i])
+ index_key = tuple(current_indices)
+ if index_key in slice_map:
+ tmp = slice_map[index_key]
+ else:
+ tmp = _op.take(tmp, _expr.const(index[i]), axis=0)
+ slice_map[index_key] = tmp
+ ret.append(_op.expand_dims(tmp, axis=0))
+
+ return _op.concatenate(ret, axis=0)
+ return _impl
+
+
def _meshgrid():
def _impl(inputs, input_types):
data = inputs[0]
@@ -2064,6 +2110,7 @@ def _get_convert_map(prelude):
"aten::type_as" : _type_as(),
"aten::gather" : _gather(),
"aten::index_select" : _select(),
+ "aten::index" : _index(),
}
return convert_map
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index e5c9634..ab0a4b0 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1202,13 +1202,13 @@ def test_forward_slice():
class Slice2(Module):
def forward(self, *args):
- return args[0][0, :, :, :]
+ return args[0][0, :, :-3, :]
class Slice3(Module):
def forward(self, *args):
x0 = torch.tensor(2) - torch.tensor(1)
x1 = torch.tensor(3) + torch.tensor(1)
- return args[0][:, x0:, :x1, :]
+ return args[0][:, x0:, 1:x1, :]
input_data = torch.rand(input_shape).float()
verify_model(Slice1().float().eval(), input_data=input_data)
@@ -2620,6 +2620,25 @@ def test_forward_matmul():
verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+def test_forward_index():
+ torch.set_grad_enabled(False)
+ input_shape = [3, 4, 5, 6]
+
+ class Index0(Module):
+ def forward(self, x):
+ return x[[0, 1], [0, 2], :2, 4]
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Index0().eval(), input_data=input_data)
+
+ class Index1(Module):
+ def forward(self, x):
+ return x[[0], [1, 2, 3, 0], [3, 1, 2, 2], [4, 2, 1, 0]]
+
+ input_data = torch.rand(input_shape).float()
+ verify_model(Index1().eval(), input_data=input_data)
+
+
def test_forward_pretrained_bert_base_uncased():
######################################################################
# This is an example how to run BERT models using TVM
@@ -2859,6 +2878,7 @@ if __name__ == "__main__":
test_adaptive_pool3d()
test_conv3d()
test_conv3d_transpose()
+ test_forward_index()
# Model tests
test_resnet18()