This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bde8a87381 [Unity][Relax] Support Dynamic Tensor as Index, torch frontend (#15884)
bde8a87381 is described below
commit bde8a8738187737a064fb334d2512335b257182b
Author: Guoyao Li <[email protected]>
AuthorDate: Wed Oct 11 10:53:09 2023 -0400
[Unity][Relax] Support Dynamic Tensor as Index, torch frontend (#15884)
* add support for torch.tensor as index
* still don't fit in array indexing
* support at most one tensor index, to avoid error
* correct tests for tensor as index
* code style
* code style
* code style
* code style
---
python/tvm/relax/frontend/torch/fx_translator.py | 56 ++++++++++++------
tests/python/relax/test_frontend_dynamo.py | 72 ++++++++++++++++++++++++
2 files changed, 110 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3d150e3eed..ba7c9bd925 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1218,6 +1218,8 @@ class TorchFXImporter:
return getattr(self.env[node.args[0]], node.args[1])
def _getitem(self, node: fx.node.Node) -> relax.Var:
+ import torch
+
x = self.env[node.args[0]]
if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
return x[node.args[1]]
@@ -1226,48 +1228,66 @@ class TorchFXImporter:
return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
assert isinstance(x.struct_info, relax.TensorStructInfo)
- begin = []
- end = []
+ take_indices = []
+ take_axes = []
+ stride_begin = []
+ stride_end = []
stride = []
- axes = []
+ stride_axes = []
expand_dim = []
i = 0
shape = self.shape_of(x)
non_ellipsis_cnt = 0
for index in node.args[1]:
- if isinstance(index, (int, slice)):
+ if isinstance(index, (int, slice, torch.fx.node.Node)):
non_ellipsis_cnt += 1
for index in node.args[1]:
if isinstance(index, int):
- begin.append(index)
- end.append(index + 1)
+ stride_begin.append(index)
+ stride_end.append(index + 1)
stride.append(1)
- axes.append(i)
+ stride_axes.append(i)
i = i + 1
elif isinstance(index, slice):
- begin.append(0 if index.start is None else index.start)
- end.append(shape[i] if index.stop is None else index.stop)
+ stride_begin.append(0 if index.start is None else index.start)
+ stride_end.append(shape[i] if index.stop is None else index.stop)
stride.append(1 if index.step is None else index.step)
- axes.append(i)
+ stride_axes.append(i)
i = i + 1
elif index is None:
- expand_dim.append(len(axes) + len(expand_dim))
+ expand_dim.append(len(stride_axes) + len(expand_dim))
elif index is Ellipsis:
for _ in range(len(shape) - non_ellipsis_cnt):
- begin.append(0)
- end.append(shape[i])
+ stride_begin.append(0)
+ stride_end.append(shape[i])
stride.append(1)
- axes.append(i)
+ stride_axes.append(i)
i += 1
+ elif isinstance(index, torch.fx.node.Node):
+ node_index = self.env[index]
+ if not isinstance(node_index, relax.Expr):
+ raise ValueError(
+ "Unsupported index type for relax.op.take: " + str(type(node_index))
+ )
+ take_indices.append(node_index)
+ take_axes.append(i)
+ i = i + 1
else:
raise ValueError("Unsupported index type: " + str(type(index)))
while i < len(shape):
- begin.append(0)
- end.append(shape[i])
+ stride_begin.append(0)
+ stride_end.append(shape[i])
stride.append(1)
- axes.append(i)
+ stride_axes.append(i)
i += 1
- sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+ taken = x
+ if len(take_indices) > 1:
+ raise ValueError("Multiple tensors as index not yet supported")
+ for each_index, each_axis in zip(take_indices, take_axes):
+ taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis))
+ sliced = self.block_builder.emit(
+ relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride)
+ )
sliced_shape = list(self.shape_of(sliced))
for i in expand_dim:
sliced_shape.insert(i, 1)
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 66179d202e..2e2ee951c8 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -437,5 +437,77 @@ def test_masked_fill():
)
+@tvm.testing.requires_gpu
+def test_getitem():
+ import torch
+ from torch.nn import Module
+
+ class Select1(Module):
+ def forward(self, input1, input2):
+ result = input1[:, input2.argmax(dim=-1), :]
+ return result
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 77, 1280), dtype="float32"),
+ inp_1: R.Tensor((1, 77), dtype="float32"),
+ ) -> R.Tensor((1, 1, 1280), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1,), dtype="int64") = R.argmax(inp_1, axis=-1, keepdims=False)
+ lv1: R.Tensor((1, 1, 1280), dtype="float32") = R.take(inp_0, lv, axis=1)
+ lv2: R.Tensor((1, 1, 1280), dtype="float32") = R.strided_slice(
+ lv1,
+ axes=[0, 2],
+ begin=[0, 0],
+ end=[1, 1280],
+ strides=[1, 1],
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((1, 1, 1280), dtype="float32") = R.reshape(lv2, R.shape([1, 1, 1280]))
+ gv: R.Tensor((1, 1, 1280), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 77, 1280), dtype="float32")
+ ) -> R.Tensor((1, 77, 1280), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(1), R.prim_value(1), dtype="int64"
+ )
+ lv1: R.Tensor((1, 77, 1280), dtype="float32") = R.take(inp_0, lv, axis=0)
+ lv2: R.Tensor((1, 77, 1280), dtype="float32") = R.strided_slice(
+ lv1,
+ axes=[1, 2],
+ begin=[0, 0],
+ end=[77, 1280],
+ strides=[1, 1],
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((1, 77, 1280), dtype="float32") = R.reshape(
+ lv2, R.shape([1, 77, 1280])
+ )
+ gv: R.Tensor((1, 77, 1280), dtype="float32") = lv3
+ R.output(gv)
+ return gv
+
+ class Select2(Module):
+ def forward(self, input1):
+ result = input1[
+ torch.arange(1),
+ ]
+ return result
+
+ verify_dynamo_model(
+ Select1(), [([1, 77, 1280], "float32"), ([1, 77], "float32")], {}, Expected1
+ )
+ verify_dynamo_model(Select2(), [([1, 77, 1280], "float32")], {}, Expected2)
+
+
if __name__ == "__main__":
tvm.testing.main()