This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fab67a9af9 [Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429)
fab67a9af9 is described below

commit fab67a9af918607542d8e6a895d53cc28030d7bd
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Oct 2 09:33:01 2024 +0900

    [Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429)
    
    * support cat and concat
    
    * support cumsum
    
    * support expand
    
    * support permute
    
    * support squeeze
    
    * support tile
    
    * support transpose
    
    * support unsqueeze
    
    * add test for flatten
    
    * support repeat
    
    * add test for reshape
    
    * support select and slice
    
    * support arange
    
    * support empty
    
    * support fill
    
    * support new_ones
    
    * support _to_copy
    
    * support split
    
    * add test for unbind
    
    * support clone
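
    Usage sketch (hedged: the module below is illustrative only, and
    `torch.export.export` / `from_exported_program` are assumed from
    PyTorch >= 2.1 and current TVM main):

        # Illustrative only; exercises tile and cat through the
        # ExportedProgram importer extended by this change.
        import torch
        from torch.export import export

        from tvm.relax.frontend.torch import from_exported_program

        class TileAndCat(torch.nn.Module):
            def forward(self, x, y):
                return torch.cat((x.tile(2, 1), y), dim=0)

        example_args = (torch.randn(2, 3), torch.randn(4, 3))
        exported_program = export(TileAndCat(), example_args)
        mod = from_exported_program(exported_program)
        mod.show()  # tile/cat are emitted as R.tile / R.concat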
---
 .../frontend/torch/base_fx_graph_translator.py     | 161 +++++
 .../frontend/torch/exported_program_translator.py  |  39 +
 python/tvm/relax/frontend/torch/fx_translator.py   | 139 ----
 .../relax/test_frontend_from_exported_program.py   | 781 +++++++++++++++++++++
 4 files changed, 981 insertions(+), 139 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 52784dc8c3..322ee04e0c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -730,6 +730,51 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Manipulation ##########
 
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+
+    def _cumsum(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
+
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
+
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -738,6 +783,122 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
+    def _split(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        split_size = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        if isinstance(split_size, (list, tuple)):
+            n_section = []
+            for s in split_size[:-1]:
+                cum_sum = 0 if not n_section else n_section[-1]
+                n_section.append(s + cum_sum)
+        else:
+            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+        return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+    def _squeeze(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
+    def _transpose(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        full_idx = list(range(len(self.shape_of(args[0]))))
+        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
+        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+
+    ########## Creation ##########
+
+    def _to_copy(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = self._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
+    def _arange(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+    def _fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
+            )
+        )
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64583d7509..1401a0bcef 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -162,6 +162,22 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
         return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
 
+    ########## Manipulation ##########
+
+    def _select(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1]
+        index = relax.const(node.args[2], "int64")
+        return self.block_builder.emit(relax.op.take(x, index, dim))
+
+    def _slice(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        axes = [node.args[1]]
+        begin = [node.args[2]]
+        end = [node.args[3]]
+        stride = [node.args[4] if len(node.args) > 4 else 1]
+        return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -249,7 +265,30 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
+            "cat.default": self._cat,
+            "concat.default": self._cat,
+            "cumsum.default": self._cumsum,
+            "expand.default": self._expand,
+            "permute.default": self._permute,
+            "repeat.default": self._repeat,
+            "select.int": self._select,
+            "slice.Tensor": self._slice,
+            "split.Tensor": self._split,
+            "squeeze.default": self._squeeze,
+            "squeeze.dim": self._squeeze,
+            "tile.default": self._tile,
+            "transpose.int": self._transpose,
+            "unsqueeze.default": lambda node: self.block_builder.emit(
+                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+            ),
             "view.default": self._reshape,
+            # tensor creation
+            "_to_copy.default": self._to_copy,
+            "arange.start": self._arange,
+            "clone.default": lambda node: self.env[node.args[0]],
+            "empty.memory_format": self._empty,
+            "fill.Scalar": self._fill,
+            "new_ones.default": self._new_ones,
             # other
             "getitem": self._getitem,
         }
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c60c7c3953..9fbc95fa7c 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -380,41 +380,12 @@ class TorchFXImporter(BaseFXGraphImporter):
 
     ########## Manipulation ##########
 
-    def _cat(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
-
     def _chunk(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         chunks = node.args[1]
         dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
         return self.block_builder.emit(relax.op.split(x, chunks, dim))
 
-    def _cumsum(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        if "dtype" in node.kwargs:
-            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        else:
-            dtype = None
-        if "out" in node.kwargs:
-            raise ValueError("specifying out for cumsum is not supported yet")
-
-        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
-
-    def _expand(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        sizes = args[1:] if len(args) > 2 else args[1]
-        broadcast_shape, in_shape = [], self.shape_of(args[0])
-        for idx, i in enumerate(sizes):
-            if isinstance(i, int) and i == -1:
-                broadcast_shape.append(in_shape[idx])
-            else:
-                broadcast_shape.append(i)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
-
     def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
         shape = self.shape_of(x)
         start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
@@ -440,22 +411,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
-    def _permute(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.permute_dims(x, dims))
-
-    def _repeat(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.tile(x, dims))
-
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -466,87 +421,8 @@ class TorchFXImporter(BaseFXGraphImporter):
         idx = node.args[1]
         return self.shape_of(x)[idx].value
 
-    def _split(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        split_size = node.args[1]
-        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
-        if isinstance(split_size, (list, tuple)):
-            n_section = []
-            for s in split_size[:-1]:
-                cum_sum = 0 if not n_section else n_section[-1]
-                n_section.append(s + cum_sum)
-        else:
-            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
-        return self.block_builder.emit(relax.op.split(x, n_section, dim))
-
-    def _squeeze(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-        return self.block_builder.emit(relax.op.squeeze(x, dim))
-
-    def _tile(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.tile(x, dims))
-
-    def _transpose(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        full_idx = list(range(len(self.shape_of(args[0]))))
-        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
-        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
-
     ########## Creation ##########
 
-    def _arange(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        start_end_step = [None, None, None]
-        if "start" in node.kwargs:
-            start_end_step[0] = node.kwargs["start"]
-        if "end" in node.kwargs:
-            start_end_step[1] = node.kwargs["end"]
-        if "step" in node.kwargs:
-            start_end_step[2] = node.kwargs["step"]
-
-        if len(node.args) == 1:
-            assert start_end_step[1] is None
-            start_end_step[1] = node.args[0]
-        elif len(node.args) == 2:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-        elif len(node.args) == 3:
-            assert start_end_step[0] is None
-            assert start_end_step[1] is None
-            assert start_end_step[2] is None
-            start_end_step[0] = node.args[0]
-            start_end_step[1] = node.args[1]
-            start_end_step[2] = node.args[2]
-
-        if start_end_step[0] is None:
-            start_end_step[0] = 0
-        if start_end_step[2] is None:
-            start_end_step[2] = 1
-
-        if "dtype" in node.kwargs:
-            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        elif any([isinstance(x, float) for x in start_end_step]):
-            dtype = self._convert_data_type(torch.get_default_dtype())
-        else:
-            dtype = "int64"
-        start_end_step = [
-            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
-        ]
-        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
-
-    def _empty(self, node: fx.Node) -> relax.Var:
-        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
-        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
-
     def _inplace_fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
@@ -596,21 +472,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         values = self.block_builder.emit(relax.op.full_like(x, rx_value))
         return self.block_builder.emit(relax.op.where(mask, values, x))
 
-    def _new_ones(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        self_var = args[0]
-        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
-        if not isinstance(size, (list, tuple)):
-            size = (size,)
-        size = relax.ShapeExpr(size)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, self_var.struct_info.dtype),
-                self_var.struct_info.dtype,
-            )
-        )
-
     def _ones(self, node: fx.Node) -> relax.Var:
         import torch
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7c887d9b96..65890ff697 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2734,6 +2734,582 @@ def test_argmax_argmin():
     verify_model(Argmin2(), example_args, {}, expected_argmin2)
 
 
+def test_cat_concat():
+    class Cat0(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y))
+
+    class Cat1(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), dim=1)
+
+    class Cat2(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), 1)
+
+    class Cat3(Module):
+        def forward(self, x, y):
+            return torch.concat((x, y), dim=0)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0)
+                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1)
+                gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+    verify_model(Cat0(), example_args, {}, Expected1)
+    verify_model(Cat1(), example_args, {}, Expected2)
+    verify_model(Cat2(), example_args, {}, Expected2)
+    verify_model(Cat3(), example_args, {}, Expected1)
+
+
+def test_cumsum():
+    class Cumsum(Module):
+        def forward(self, input):
+            return torch.cumsum(input, dim=1, dtype=torch.int32)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Cumsum(), example_args, {}, expected1)
+
+
+def test_expand():
+    class Expand1(Module):
+        def forward(self, x):
+            return x.expand(4, 2, 3, 4)
+
+    class Expand2(Module):
+        def forward(self, x):
+            return x.expand(4, -1, -1, 4)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4))
+                gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Expand1(), example_args, {}, expected1)
+    verify_model(Expand2(), example_args, {}, expected1)
+
+
+def test_flatten():
+    class Flatten(Module):
+        def __init__(self):
+            super().__init__()
+            self.f = torch.nn.Flatten(2, -1)
+
+        def forward(self, input):
+            return self.f(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100))
+                gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Flatten(), example_args, {}, expected1)
+
+
+def test_permute():
+    class Permute1(Module):
+        def forward(self, x):
+            return x.permute(0, 3, 2, 1)
+
+    class Permute2(Module):
+        def forward(self, x):
+            return torch.permute(x, (0, 3, 2, 1))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Permute1(), example_args, {}, expected1)
+    verify_model(Permute2(), example_args, {}, expected1)
+
+
+def test_repeat():
+    class Tile1(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(2)
+
+    class Tile2(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(4, 2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
+                gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, dtype=torch.float32),)
+    verify_model(Tile1(), example_args, {}, expected1)
+
+    example_args = (torch.randn(1, 3, dtype=torch.float32),)
+    verify_model(Tile2(), example_args, {}, expected2)
+
+
+def test_reshape():
+    class Reshape(Module):
+        def forward(self, x):
+            return x.reshape(2, 12)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Reshape(), example_args, {}, expected1)
+
+
+def test_select_slice():
+    class Slice1(Module):
+        def forward(self, x):
+            return x[0, 1::2, :, :3]
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x, R.const(0, "int64"), axis=0)
+                lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+                    lv,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(2),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+                    lv1,
+                    (R.prim_value(1),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(2),),
+                    (R.prim_value(0),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class Slice2(Module):
+        def forward(self, x):
+            return x[:, None, None, :, None]
+
+    @I.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((8, 16), dtype="float32")
+        ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
+                    x,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1])
+                lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2])
+                lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice(
+                    lv2,
+                    (R.prim_value(3),),
+                    (R.prim_value(0),),
+                    (R.prim_value(9223372036854775807),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4])
+                gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Slice1(), example_args, {}, expected1)
+
+    example_args = (torch.randn(8, 16, dtype=torch.float32),)
+    verify_model(Slice2(), example_args, {}, expected2)
+
+
+def test_split():
+    class Chunk(Module):
+        def forward(self, input):
+            return torch.chunk(input, 3, dim=1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
+                lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = (lv1, lv2, lv3)
+                R.output(gv)
+            return gv
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((0, 3, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0])
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 0, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Chunk(), example_args, {}, Expected)
+
+    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
+
+
+def test_squeeze():
+    class Squeeze1(Module):
+        def forward(self, input):
+            return input.squeeze(1)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1])
+                gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Squeeze2(Module):
+        def forward(self, input):
+            return input.squeeze()
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
+
+    verify_model(Squeeze1(), example_args, {}, Expected1)
+    verify_model(Squeeze2(), example_args, {}, Expected2)
+
+
+def test_tile():
+    class Tile1(Module):
+        def forward(self, x):
+            return x.tile((2,))
+
+    class Tile2(Module):
+        def forward(self, x):
+            return x.tile(4, 2)
+
+    class Tile3(Module):
+        def forward(self, x):
+            return torch.tile(x, (4, 2))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
+                gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, dtype=torch.float32),)
+    verify_model(Tile1(), example_args, {}, expected1)
+    verify_model(Tile2(), example_args, {}, expected2)
+    verify_model(Tile3(), example_args, {}, expected2)
+
+
+def test_transpose():
+    class Transpose(Module):
+        def forward(self, x):
+            return x.transpose(1, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(Transpose(), example_args, {}, expected1)
+
+
+def test_unsqueeze():
+    class Unsqueeze1(Module):
+        def forward(self, input):
+            return input.unsqueeze(1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1)
+                gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class Unsqueeze2(Module):
+        def forward(self, input):
+            return input.unsqueeze(-1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    verify_model(Unsqueeze1(), example_args, {}, expected1)
+    verify_model(Unsqueeze2(), example_args, {}, expected2)
+
+
 def test_view():
     class View(Module):
         def forward(self, x):
@@ -2756,6 +3332,211 @@ def test_view():
     verify_model(View(), example_args, {}, expected1)
 
 
+def test_arange():
+    class Arange(Module):
+        def forward(self, input):
+            return torch.arange(0, 20, dtype=torch.int32)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((20,), dtype="int32")):
+            with R.dataflow():
+                lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32")
+                gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Arange(), example_args, {}, Expected)
+
+
+def test_clone():
+    class Clone(Module):
+        def forward(self, input):
+            return torch.clone(input)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Clone(), example_args, {}, Expected)
+
+
+def test_empty():
+    class Empty(Module):
+        def forward(self, input):
+            return torch.empty((10, 10), dtype=torch.float32)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
+                    R.shape([10, 10]), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Empty(), example_args, {}, Expected)
+
+
+def test_fill():
+    class Fill(Module):
+        def forward(self, input: torch.Tensor):
+            return torch.fill(input, 1.5)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.full(
+                    R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(10, 10, dtype=torch.float32),)
+    verify_model(Fill(), example_args, {}, Expected)
+
+
+def test_new_ones():
+    class NewOnes(Module):
+        def forward(self, x):
+            return x.new_ones(1, 2, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
+                    (1, 2, 3), R.const(1, "float32"), dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
+    verify_model(NewOnes(), example_args, {}, expected1)
+
+
+def test_to_copy():
+    # float
+    class ToFloat(Module):
+        def forward(self, x):
+            return x.float()
+
+    @tvm.script.ir_module
+    class expected_float:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # half
+    class ToHalf(Module):
+        def forward(self, x):
+            return x.half()
+
+    @tvm.script.ir_module
+    class expected_half:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # type
+    class Type(Module):
+        def forward(self, x):
+            return x.type(torch.float32)
+
+    @tvm.script.ir_module
+    class expected_type:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+                R.output(gv)
+            return gv
+
+    class To1(Module):
+        def forward(self, input):
+            return input.to(torch.float16)
+
+    @I.ir_module
+    class expected_to1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class To2(Module):
+        def forward(self, input):
+            return input.to("cpu")
+
+    @I.ir_module
+    class expected_to2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32")
+                gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(ToFloat(), example_args, {}, expected_float)
+    verify_model(ToHalf(), example_args, {}, expected_half)
+    verify_model(Type(), example_args, {}, expected_type)
+    verify_model(To1(), example_args, {}, expected_to1)
+    verify_model(To2(), example_args, {}, expected_to2)
+
+
 def test_keep_params():
     class Conv2D1(Module):
         def __init__(self):

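For reference, the chunk-size arithmetic in the new `_split` helper can be
checked in isolation (plain-Python sketch; `to_relax_sections` is a
hypothetical name, not part of the commit):

    def to_relax_sections(split_size, dim_extent):
        # PyTorch's split takes per-chunk sizes or a uniform chunk size;
        # relax.op.split takes cumulative cut indices or a section count.
        if isinstance(split_size, (list, tuple)):
            indices, running = [], 0
            for s in split_size[:-1]:  # the last chunk is implied
                running += s
                indices.append(running)
            return indices
        return (dim_extent + split_size - 1) // split_size  # ceil division

    assert to_relax_sections([1, 1, 1], 3) == [1, 2]
    assert to_relax_sections(4, 10) == 3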
