This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bde8a87381 [Unity][Relax] Support Dynamic Tensor as Index, torch frontend (#15884)
bde8a87381 is described below

commit bde8a8738187737a064fb334d2512335b257182b
Author: Guoyao Li <[email protected]>
AuthorDate: Wed Oct 11 10:53:09 2023 -0400

    [Unity][Relax] Support Dynamic Tensor as Index, torch frontend (#15884)
    
    * add support for torch.tensor as index
    
    * still don't fit in array indexing
    
    * support at most one tensor index, to avoid error
    
    * correct tests for tensor as index
    
    * code style
    
    * code style
    
    * code style
    
    * code style
---
 python/tvm/relax/frontend/torch/fx_translator.py | 56 ++++++++++++------
 tests/python/relax/test_frontend_dynamo.py       | 72 ++++++++++++++++++++++++
 2 files changed, 110 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3d150e3eed..ba7c9bd925 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1218,6 +1218,8 @@ class TorchFXImporter:
         return getattr(self.env[node.args[0]], node.args[1])
 
     def _getitem(self, node: fx.node.Node) -> relax.Var:
+        import torch
+
         x = self.env[node.args[0]]
         if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
             return x[node.args[1]]
@@ -1226,48 +1228,66 @@ class TorchFXImporter:
+                return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
 
             assert isinstance(x.struct_info, relax.TensorStructInfo)
-            begin = []
-            end = []
+            take_indices = []
+            take_axes = []
+            stride_begin = []
+            stride_end = []
             stride = []
-            axes = []
+            stride_axes = []
             expand_dim = []
             i = 0
             shape = self.shape_of(x)
             non_ellipsis_cnt = 0
             for index in node.args[1]:
-                if isinstance(index, (int, slice)):
+                if isinstance(index, (int, slice, torch.fx.node.Node)):
                     non_ellipsis_cnt += 1
             for index in node.args[1]:
                 if isinstance(index, int):
-                    begin.append(index)
-                    end.append(index + 1)
+                    stride_begin.append(index)
+                    stride_end.append(index + 1)
                     stride.append(1)
-                    axes.append(i)
+                    stride_axes.append(i)
                     i = i + 1
                 elif isinstance(index, slice):
-                    begin.append(0 if index.start is None else index.start)
-                    end.append(shape[i] if index.stop is None else index.stop)
+                    stride_begin.append(0 if index.start is None else index.start)
+                    stride_end.append(shape[i] if index.stop is None else index.stop)
                     stride.append(1 if index.step is None else index.step)
-                    axes.append(i)
+                    stride_axes.append(i)
                     i = i + 1
                 elif index is None:
-                    expand_dim.append(len(axes) + len(expand_dim))
+                    expand_dim.append(len(stride_axes) + len(expand_dim))
                 elif index is Ellipsis:
                     for _ in range(len(shape) - non_ellipsis_cnt):
-                        begin.append(0)
-                        end.append(shape[i])
+                        stride_begin.append(0)
+                        stride_end.append(shape[i])
                         stride.append(1)
-                        axes.append(i)
+                        stride_axes.append(i)
                         i += 1
+                elif isinstance(index, torch.fx.node.Node):
+                    node_index = self.env[index]
+                    if not isinstance(node_index, relax.Expr):
+                        raise ValueError(
+                            "Unsupported index type for relax.op.take: " + str(type(node_index))
+                        )
+                    take_indices.append(node_index)
+                    take_axes.append(i)
+                    i = i + 1
                 else:
                     raise ValueError("Unsupported index type: " + str(type(index)))
             while i < len(shape):
-                begin.append(0)
-                end.append(shape[i])
+                stride_begin.append(0)
+                stride_end.append(shape[i])
                 stride.append(1)
-                axes.append(i)
+                stride_axes.append(i)
                 i += 1
-            sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+            taken = x
+            if len(take_indices) > 1:
+                raise ValueError("Multiple tensors as index not yet supported")
+            for each_index, each_axis in zip(take_indices, take_axes):
+                taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis))
+            sliced = self.block_builder.emit(
+                relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride)
+            )
             sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py
index 66179d202e..2e2ee951c8 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -437,5 +437,77 @@ def test_masked_fill():
     )
 
 
[email protected]_gpu
+def test_getitem():
+    import torch
+    from torch.nn import Module
+
+    class Select1(Module):
+        def forward(self, input1, input2):
+            result = input1[:, input2.argmax(dim=-1), :]
+            return result
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 77, 1280), dtype="float32"),
+            inp_1: R.Tensor((1, 77), dtype="float32"),
+        ) -> R.Tensor((1, 1, 1280), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1,), dtype="int64") = R.argmax(inp_1, axis=-1, keepdims=False)
+                lv1: R.Tensor((1, 1, 1280), dtype="float32") = R.take(inp_0, lv, axis=1)
+                lv2: R.Tensor((1, 1, 1280), dtype="float32") = R.strided_slice(
+                    lv1,
+                    axes=[0, 2],
+                    begin=[0, 0],
+                    end=[1, 1280],
+                    strides=[1, 1],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((1, 1, 1280), dtype="float32") = R.reshape(lv2, R.shape([1, 1, 1280]))
+                gv: R.Tensor((1, 1, 1280), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 77, 1280), dtype="float32")
+        ) -> R.Tensor((1, 77, 1280), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(1), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 77, 1280), dtype="float32") = R.take(inp_0, lv, axis=0)
+                lv2: R.Tensor((1, 77, 1280), dtype="float32") = R.strided_slice(
+                    lv1,
+                    axes=[1, 2],
+                    begin=[0, 0],
+                    end=[77, 1280],
+                    strides=[1, 1],
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((1, 77, 1280), dtype="float32") = R.reshape(
+                    lv2, R.shape([1, 77, 1280])
+                )
+                gv: R.Tensor((1, 77, 1280), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    class Select2(Module):
+        def forward(self, input1):
+            result = input1[
+                torch.arange(1),
+            ]
+            return result
+
+    verify_dynamo_model(
+        Select1(), [([1, 77, 1280], "float32"), ([1, 77], "float32")], {}, Expected1
+    )
+    verify_dynamo_model(Select2(), [([1, 77, 1280], "float32")], {}, Expected2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to