This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fab67a9af9 [Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429)
fab67a9af9 is described below
commit fab67a9af918607542d8e6a895d53cc28030d7bd
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Wed Oct 2 09:33:01 2024 +0900
[Relax][PyTorch] Support tensor manipulation and creation ops for ExportedProgram importer (#17429)
* support cat and concat
* support cumsum
* support expand
* support permute
* support squeeze
* support tile
* support transpose
* support unsqueeze
* add test for flatten
* support repeat
* add test for reshape
* support select and slice
* support arange
* support empty
* support fill
* support new_ones
* support _to_copy
* support split
* add test for unbind
* support clone
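
For reference, a minimal usage sketch (illustrative only, not part of this
commit): it runs a few of the newly supported ops through torch.export and
the ExportedProgram importer, assuming the from_exported_program entry point
exposed by this frontend.

    import torch
    from torch.export import export

    from tvm.relax.frontend.torch import from_exported_program

    class Manipulate(torch.nn.Module):
        def forward(self, x, y):
            z = torch.cat((x, y), dim=1)  # cat.default
            z = z.transpose(0, 1)         # transpose.int
            return z.unsqueeze(0)         # unsqueeze.default

    # Export with torch.export, then import the ExportedProgram into Relax.
    example_args = (torch.randn(2, 3), torch.randn(2, 3))
    exported_program = export(Manipulate(), args=example_args)
    mod = from_exported_program(exported_program)  # returns a Relax IRModule
    mod.show()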
---
.../frontend/torch/base_fx_graph_translator.py | 161 +++++
.../frontend/torch/exported_program_translator.py | 39 +
python/tvm/relax/frontend/torch/fx_translator.py | 139 ----
.../relax/test_frontend_from_exported_program.py | 781 +++++++++++++++++++++
4 files changed, 981 insertions(+), 139 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 52784dc8c3..322ee04e0c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -730,6 +730,51 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
########## Manipulation ##########
+ def _cat(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+ return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+
+ def _cumsum(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ if "dtype" in node.kwargs:
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ else:
+ dtype = None
+ if "out" in node.kwargs:
+ raise ValueError("specifying out for cumsum is not supported yet")
+
+ return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
+ def _expand(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ sizes = args[1:] if len(args) > 2 else args[1]
+ broadcast_shape, in_shape = [], self.shape_of(args[0])
+ for idx, i in enumerate(sizes):
+ if isinstance(i, int) and i == -1:
+ broadcast_shape.append(in_shape[idx])
+ else:
+ broadcast_shape.append(i)
+ return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+
+ def _permute(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+ return self.block_builder.emit(relax.op.permute_dims(x, dims))
+
+ def _repeat(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+ return self.block_builder.emit(relax.op.tile(x, dims))
+
def _reshape(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
@@ -738,6 +783,122 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.reshape(x, dims))
+ def _split(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ split_size = node.args[1]
+ dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+ if isinstance(split_size, (list, tuple)):
+ n_section = []
+ for s in split_size[:-1]:
+ cum_sum = 0 if not n_section else n_section[-1]
+ n_section.append(s + cum_sum)
+ else:
+ n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+ return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+ def _squeeze(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+ return self.block_builder.emit(relax.op.squeeze(x, dim))
+
+ def _tile(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+ return self.block_builder.emit(relax.op.tile(x, dims))
+
+ def _transpose(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ full_idx = list(range(len(self.shape_of(args[0]))))
+ full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
+ return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+
+ ########## Creation ##########
+
+ def _to_copy(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ x = self.env[node.args[0]]
+ if len(node.args) == 2:
+ if isinstance(node.args[1], torch.dtype):
+ dtype = self._convert_data_type(node.args[1], self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ elif "dtype" in node.kwargs:
+ dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ return x
+
+ def _arange(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ start_end_step = [None, None, None]
+ if "start" in node.kwargs:
+ start_end_step[0] = node.kwargs["start"]
+ if "end" in node.kwargs:
+ start_end_step[1] = node.kwargs["end"]
+ if "step" in node.kwargs:
+ start_end_step[2] = node.kwargs["step"]
+
+ if len(node.args) == 1:
+ assert start_end_step[1] is None
+ start_end_step[1] = node.args[0]
+ elif len(node.args) == 2:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ elif len(node.args) == 3:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ assert start_end_step[2] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ start_end_step[2] = node.args[2]
+
+ if start_end_step[0] is None:
+ start_end_step[0] = 0
+ if start_end_step[2] is None:
+ start_end_step[2] = 1
+
+ if "dtype" in node.kwargs:
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ elif any([isinstance(x, float) for x in start_end_step]):
+ dtype = self._convert_data_type(torch.get_default_dtype())
+ else:
+ dtype = "int64"
+ start_end_step = [
+ self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+ ]
+ return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+
+ def _empty(self, node: fx.Node) -> relax.Var:
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+ def _fill(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+ return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+
+ def _new_ones(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ self_var = args[0]
+ size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+ if not isinstance(size, (list, tuple)):
+ size = (size,)
+ size = relax.ShapeExpr(size)
+ return self.block_builder.emit(
+ relax.op.full(
+ size,
+ relax.const(1, self_var.struct_info.dtype),
+ self_var.struct_info.dtype,
+ )
+ )
+
########## Others ##########
def _getitem(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 64583d7509..1401a0bcef 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -162,6 +162,22 @@ class ExportedProgramImporter(BaseFXGraphImporter):
scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
+ ########## Manipulation ##########
+
+ def _select(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1]
+ index = relax.const(node.args[2], "int64")
+ return self.block_builder.emit(relax.op.take(x, index, dim))
+
+ def _slice(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ axes = [node.args[1]]
+ begin = [node.args[2]]
+ end = [node.args[3]]
+ stride = [node.args[4] if len(node.args) > 4 else 1]
+ return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -249,7 +265,30 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
+ "cat.default": self._cat,
+ "concat.default": self._cat,
+ "cumsum.default": self._cumsum,
+ "expand.default": self._expand,
+ "permute.default": self._permute,
+ "repeat.default": self._repeat,
+ "select.int": self._select,
+ "slice.Tensor": self._slice,
+ "split.Tensor": self._split,
+ "squeeze.default": self._squeeze,
+ "squeeze.dim": self._squeeze,
+ "tile.default": self._tile,
+ "transpose.int": self._transpose,
+ "unsqueeze.default": lambda node: self.block_builder.emit(
+ relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+ ),
"view.default": self._reshape,
+ # tensor creation
+ "_to_copy.default": self._to_copy,
+ "arange.start": self._arange,
+ "clone.default": lambda node: self.env[node.args[0]],
+ "empty.memory_format": self._empty,
+ "fill.Scalar": self._fill,
+ "new_ones.default": self._new_ones,
# other
"getitem": self._getitem,
}
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index c60c7c3953..9fbc95fa7c 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -380,41 +380,12 @@ class TorchFXImporter(BaseFXGraphImporter):
########## Manipulation ##########
- def _cat(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
- return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
-
def _chunk(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
chunks = node.args[1]
dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
return self.block_builder.emit(relax.op.split(x, chunks, dim))
- def _cumsum(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
-
- dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
- if "dtype" in node.kwargs:
- dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
- else:
- dtype = None
- if "out" in node.kwargs:
- raise ValueError("specifying out for cumsum is not supported yet")
-
- return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
-
- def _expand(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- sizes = args[1:] if len(args) > 2 else args[1]
- broadcast_shape, in_shape = [], self.shape_of(args[0])
- for idx, i in enumerate(sizes):
- if isinstance(i, int) and i == -1:
- broadcast_shape.append(in_shape[idx])
- else:
- broadcast_shape.append(i)
- return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
-
def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
shape = self.shape_of(x)
start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
@@ -440,22 +411,6 @@ class TorchFXImporter(BaseFXGraphImporter):
end_dim = module.end_dim
return self._flatten_impl(x, start_dim, end_dim)
- def _permute(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- x = args[0]
- dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
- return self.block_builder.emit(relax.op.permute_dims(x, dims))
-
- def _repeat(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- x = args[0]
- dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
- return self.block_builder.emit(relax.op.tile(x, dims))
-
def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -466,87 +421,8 @@ class TorchFXImporter(BaseFXGraphImporter):
idx = node.args[1]
return self.shape_of(x)[idx].value
- def _split(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- split_size = node.args[1]
- dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
- if isinstance(split_size, (list, tuple)):
- n_section = []
- for s in split_size[:-1]:
- cum_sum = 0 if not n_section else n_section[-1]
- n_section.append(s + cum_sum)
- else:
- n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
- return self.block_builder.emit(relax.op.split(x, n_section, dim))
-
- def _squeeze(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
- return self.block_builder.emit(relax.op.squeeze(x, dim))
-
- def _tile(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- x = args[0]
- dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
- return self.block_builder.emit(relax.op.tile(x, dims))
-
- def _transpose(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- full_idx = list(range(len(self.shape_of(args[0]))))
- full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
- return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
-
########## Creation ##########
- def _arange(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- start_end_step = [None, None, None]
- if "start" in node.kwargs:
- start_end_step[0] = node.kwargs["start"]
- if "end" in node.kwargs:
- start_end_step[1] = node.kwargs["end"]
- if "step" in node.kwargs:
- start_end_step[2] = node.kwargs["step"]
-
- if len(node.args) == 1:
- assert start_end_step[1] is None
- start_end_step[1] = node.args[0]
- elif len(node.args) == 2:
- assert start_end_step[0] is None
- assert start_end_step[1] is None
- start_end_step[0] = node.args[0]
- start_end_step[1] = node.args[1]
- elif len(node.args) == 3:
- assert start_end_step[0] is None
- assert start_end_step[1] is None
- assert start_end_step[2] is None
- start_end_step[0] = node.args[0]
- start_end_step[1] = node.args[1]
- start_end_step[2] = node.args[2]
-
- if start_end_step[0] is None:
- start_end_step[0] = 0
- if start_end_step[2] is None:
- start_end_step[2] = 1
-
- if "dtype" in node.kwargs:
- dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
- elif any([isinstance(x, float) for x in start_end_step]):
- dtype = self._convert_data_type(torch.get_default_dtype())
- else:
- dtype = "int64"
- start_end_step = [
- self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
- ]
- return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
-
- def _empty(self, node: fx.Node) -> relax.Var:
- dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
- return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
-
def _inplace_fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
@@ -596,21 +472,6 @@ class TorchFXImporter(BaseFXGraphImporter):
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
return self.block_builder.emit(relax.op.where(mask, values, x))
- def _new_ones(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- self_var = args[0]
- size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
- if not isinstance(size, (list, tuple)):
- size = (size,)
- size = relax.ShapeExpr(size)
- return self.block_builder.emit(
- relax.op.full(
- size,
- relax.const(1, self_var.struct_info.dtype),
- self_var.struct_info.dtype,
- )
- )
-
def _ones(self, node: fx.Node) -> relax.Var:
import torch
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 7c887d9b96..65890ff697 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2734,6 +2734,582 @@ def test_argmax_argmin():
verify_model(Argmin2(), example_args, {}, expected_argmin2)
+def test_cat_concat():
+ class Cat0(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y))
+
+ class Cat1(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y), dim=1)
+
+ class Cat2(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y), 1)
+
+ class Cat3(Module):
+ def forward(self, x, y):
+ return torch.concat((x, y), dim=0)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0)
+ gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1)
+ gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32))
+ verify_model(Cat0(), example_args, {}, Expected1)
+ verify_model(Cat1(), example_args, {}, Expected2)
+ verify_model(Cat2(), example_args, {}, Expected2)
+ verify_model(Cat3(), example_args, {}, Expected1)
+
+
+def test_cumsum():
+ class Cumsum(Module):
+ def forward(self, input):
+ return torch.cumsum(input, dim=1, dtype=torch.int32)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1,
axis=1, dtype="int32")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(Cumsum(), example_args, {}, expected1)
+
+
+def test_expand():
+ class Expand1(Module):
+ def forward(self, x):
+ return x.expand(4, 2, 3, 4)
+
+ class Expand2(Module):
+ def forward(self, x):
+ return x.expand(4, -1, -1, 4)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4))
+ gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(Expand1(), example_args, {}, expected1)
+ verify_model(Expand2(), example_args, {}, expected1)
+
+
+def test_flatten():
+ class Flatten(Module):
+ def __init__(self):
+ super().__init__()
+ self.f = torch.nn.Flatten(2, -1)
+
+ def forward(self, input):
+ return self.f(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100))
+ gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Flatten(), example_args, {}, expected1)
+
+
+def test_permute():
+ class Permute1(Module):
+ def forward(self, x):
+ return x.permute(0, 3, 2, 1)
+
+ class Permute2(Module):
+ def forward(self, x):
+ return torch.permute(x, (0, 3, 2, 1))
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+ gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(Permute1(), example_args, {}, expected1)
+ verify_model(Permute2(), example_args, {}, expected1)
+
+
+def test_repeat():
+ class Tile1(Module):
+ def forward(self, x: torch.Tensor):
+ return x.repeat(2)
+
+ class Tile2(Module):
+ def forward(self, x: torch.Tensor):
+ return x.repeat(4, 2)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
+ gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+ gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(3, dtype=torch.float32),)
+ verify_model(Tile1(), example_args, {}, expected1)
+
+ example_args = (torch.randn(1, 3, dtype=torch.float32),)
+ verify_model(Tile2(), example_args, {}, expected2)
+
+
+def test_reshape():
+ class Reshape(Module):
+ def forward(self, x):
+ return x.reshape(2, 12)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+ gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(Reshape(), example_args, {}, expected1)
+
+
+def test_select_slice():
+ class Slice1(Module):
+ def forward(self, x):
+ return x[0, 1::2, :, :3]
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x,
R.const(0, "int64"), axis=0)
+ lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+ lv,
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(9223372036854775807),),
+ (R.prim_value(2),),
+ assume_inbound=False,
+ )
+ lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice(
+ lv1,
+ (R.prim_value(1),),
+ (R.prim_value(0),),
+ (R.prim_value(9223372036854775807),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice(
+ lv2,
+ (R.prim_value(2),),
+ (R.prim_value(0),),
+ (R.prim_value(3),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ class Slice2(Module):
+ def forward(self, x):
+ return x[:, None, None, :, None]
+
+ @I.ir_module
+ class expected2:
+ @R.function
+ def main(
+ x: R.Tensor((8, 16), dtype="float32")
+ ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice(
+ x,
+ (R.prim_value(0),),
+ (R.prim_value(0),),
+ (R.prim_value(9223372036854775807),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1])
+ lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2])
+ lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice(
+ lv2,
+ (R.prim_value(3),),
+ (R.prim_value(0),),
+ (R.prim_value(9223372036854775807),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4])
+ gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Slice1(), example_args, {}, expected1)
+
+ example_args = (torch.randn(8, 16, dtype=torch.float32),)
+ verify_model(Slice2(), example_args, {}, expected2)
+
+
+def test_split():
+ class Chunk(Module):
+ def forward(self, input):
+ return torch.chunk(input, 3, dim=1)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=3, axis=1)
+ lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
+ lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
+ lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
+ gv: R.Tuple(
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ R.Tensor((1, 1, 10, 10), dtype="float32"),
+ ) = (lv1, lv2, lv3)
+ R.output(gv)
+ return gv
+
+ class Unbind1(Module):
+ def forward(self, data):
+ return torch.unbind(data)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((0, 3, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[0])
+ lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+ lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[0])
+ lv7: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv2, lv4, lv6)
+ lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+ lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+ lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ gv: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv8, lv9, lv10)
+ R.output(gv)
+ return gv
+
+ class Unbind2(Module):
+ def forward(self, data):
+ return torch.unbind(data, dim=1)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 0, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+ lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
+ lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[1])
+ lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+ lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[1])
+ lv7: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv2, lv4, lv6)
+ lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+ lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+ lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ gv: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv8, lv9, lv10)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Chunk(), example_args, {}, Expected)
+
+ example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Unbind1(), example_args, {}, expected1)
+ verify_model(Unbind2(), example_args, {}, expected2)
+
+
+def test_squeeze():
+ class Squeeze1(Module):
+ def forward(self, input):
+ return input.squeeze(1)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0,
axis=[1])
+ gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class Squeeze2(Module):
+ def forward(self, input):
+ return input.squeeze()
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0,
axis=None)
+ gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
+
+ verify_model(Squeeze1(), example_args, {}, Expected1)
+ verify_model(Squeeze2(), example_args, {}, Expected2)
+
+
+def test_tile():
+ class Tile1(Module):
+ def forward(self, x):
+ return x.tile((2,))
+
+ class Tile2(Module):
+ def forward(self, x):
+ return x.tile(4, 2)
+
+ class Tile3(Module):
+ def forward(self, x):
+ return torch.tile(x, (4, 2))
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2])
+ gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ x: R.Tensor((1, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+ gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, dtype=torch.float32),)
+ verify_model(Tile1(), example_args, {}, expected1)
+ verify_model(Tile2(), example_args, {}, expected2)
+ verify_model(Tile3(), example_args, {}, expected2)
+
+
+def test_transpose():
+ class Transpose(Module):
+ def forward(self, x):
+ return x.transpose(1, 3)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
+ gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(Transpose(), example_args, {}, expected1)
+
+
+def test_unsqueeze():
+ class Unsqueeze1(Module):
+ def forward(self, input):
+ return input.unsqueeze(1)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1)
+ gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class Unsqueeze2(Module):
+ def forward(self, input):
+ return input.unsqueeze(-1)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ verify_model(Unsqueeze1(), example_args, {}, expected1)
+ verify_model(Unsqueeze2(), example_args, {}, expected2)
+
+
def test_view():
class View(Module):
def forward(self, x):
@@ -2756,6 +3332,211 @@ def test_view():
verify_model(View(), example_args, {}, expected1)
+def test_arange():
+ class Arange(Module):
+ def forward(self, input):
+ return torch.arange(0, 20, dtype=torch.int32)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((20,), dtype="int32")):
+ with R.dataflow():
+ lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1,
dtype="int32")
+ gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(10, 10, dtype=torch.float32),)
+ verify_model(Arange(), example_args, {}, Expected)
+
+
+def test_clone():
+ class Clone(Module):
+ def forward(self, input):
+ return torch.clone(input)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ with R.dataflow():
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(10, 10, dtype=torch.float32),)
+ verify_model(Clone(), example_args, {}, Expected)
+
+
+def test_empty():
+ class Empty(Module):
+ def forward(self, input):
+ return torch.empty((10, 10), dtype=torch.float32)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.zeros(
+ R.shape([10, 10]), dtype="float32"
+ )
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(10, 10, dtype=torch.float32),)
+ verify_model(Empty(), example_args, {}, Expected)
+
+
+def test_fill():
+ class Fill(Module):
+ def forward(self, input: torch.Tensor):
+ return torch.fill(input, 1.5)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.full(
+ R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32"
+ )
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(10, 10, dtype=torch.float32),)
+ verify_model(Fill(), example_args, {}, Expected)
+
+
+def test_new_ones():
+ class NewOnes(Module):
+ def forward(self, x):
+ return x.new_ones(1, 2, 3)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
+ (1, 2, 3), R.const(1, "float32"), dtype="float32"
+ )
+ gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
+ verify_model(NewOnes(), example_args, {}, expected1)
+
+
+def test_to_copy():
+ # float
+ class ToFloat(Module):
+ def forward(self, x):
+ return x.float()
+
+ @tvm.script.ir_module
+ class expected_float:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x,
dtype="float32")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # half
+ class ToHalf(Module):
+ def forward(self, x):
+ return x.half()
+
+ @tvm.script.ir_module
+ class expected_half:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x,
dtype="float16")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+ R.output(gv)
+ return gv
+
+ # type
+ class Type(Module):
+ def forward(self, x):
+ return x.type(torch.float32)
+
+ @tvm.script.ir_module
+ class expected_type:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (x,)
+ R.output(gv)
+ return gv
+
+ class To1(Module):
+ def forward(self, input):
+ return input.to(torch.float16)
+
+ @I.ir_module
+ class expected_to1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")):
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0,
dtype="float16")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class To2(Module):
+ def forward(self, input):
+ return input.to("cpu")
+
+ @I.ir_module
+ class expected_to2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0,
dtype="float32")
+ gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(ToFloat(), example_args, {}, expected_float)
+ verify_model(ToHalf(), example_args, {}, expected_half)
+ verify_model(Type(), example_args, {}, expected_type)
+ verify_model(To1(), example_args, {}, expected_to1)
+ verify_model(To2(), example_args, {}, expected_to2)
+
+
def test_keep_params():
class Conv2D1(Module):
def __init__(self):