This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4ab3f82669 [Relax][PyTorch] Cleanup Tensor Manipulation and Creation op converters (#17376)
4ab3f82669 is described below
commit 4ab3f82669fb20d77cae47704c857ab39a577417
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Mon Sep 16 23:13:41 2024 +0900
[Relax][PyTorch] Cleanup Tensor Manipulation and Creation op converters (#17376)
* cleanup `_cat()`
* cleanup `_cumsum()`
* cleanup `_expand()`
* cleanup `_flatten()`
* cleanup `_permute()`
* cleanup `_repeat()`
* cleanup `_reshape()`
* cleanup `_size()`
* cleanup `_split()`
* cleanup `_squeeze()`
* cleanup `_tile()`
* cleanup `_transpose()`
* cleanup `_chunk()`
* cleanup `_arange()`
* cleanup `_empty()`
* cleanup `_inplace_fill()`
* cleanup `_full()`
* cleanup `_index_select()`
* cleanup `_inplace_masked_fill()`
* cleanup `_masked_fill()`
* cleanup `_new_ones()`
* cleanup `_ones()`
* cleanup `_tensor()`
* `_inplace_tril_triu()` is a unary op; move it next to `_tril_triu()`
* `_batch_norm_2d()` is an nn op; move it to the nn ops and rename it `_batch_norm_2d_module()`
* `_interpolate()` is an nn op; move it to the nn ops
* `_cross_entropy()` is an nn op; move it to the nn ops and split out `_cross_entropy_module()`
* chore
* fix tensor size
---
python/tvm/relax/frontend/torch/fx_translator.py | 755 +++++++++++------------
1 file changed, 358 insertions(+), 397 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 4dc49d20ff..983bce0255 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -212,6 +212,20 @@ class TorchFXImporter:
assert dim is not None
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+ def _inplace_tril_triu(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else 0
+ assert isinstance(k, int)
+
+ mutated = self.block_builder.emit(op(x, k))
+ self.env[node.args[0]] = mutated
+ return mutated
+
+ return convert
+
def _tril_triu(self, op: Callable) -> Callable:
from torch import fx
@@ -356,6 +370,29 @@ class TorchFXImporter:
res = bias if res is None else
self.block_builder.emit(relax.op.add(res, bias))
return res
+ def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+ bias = self.params[module.bias]
+ running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
+ running_var = self._convert_torch_tensor_to_relax(module.running_var)
+ eps = module.eps
+
+ res_tuple = self.block_builder.emit(
+ relax.op.nn.batch_norm(
+ x,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ axis=1,
+ epsilon=eps,
+ )
+ )
+
+ return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+
def _conv1d_transpose_impl(
self,
x: relax.Expr,
@@ -683,6 +720,40 @@ class TorchFXImporter:
groups=module.groups,
)
+ def _cross_entropy(self, node: fx.Node) -> relax.Expr:
+ preds = self.env[node.args[0]]
+ targets = self.env[node.args[1]]
+ weights = self.env.get(node.kwargs["weight"], None)
+ reduction = node.kwargs["reduction"]
+ ignore_index = node.kwargs["ignore_index"]
+
+ return self.block_builder.emit(
+ relax.op.nn.nll_loss(
+ relax.op.nn.log_softmax(preds), targets, weights, reduction,
ignore_index
+ )
+ )
+
+ def _cross_entropy_module(self, node: fx.Node) -> relax.Expr:
+ preds = self.env[node.args[0]]
+ targets = self.env[node.args[1]]
+ module = self.named_modules[node.target]
+
+ weights = module.weight
+ if weights is not None:
+ if weights in self.params:
+ weights = self.params[weights]
+ else:
+ weights = relax.const(weights.numpy(), preds.struct_info.dtype)
+
+ reduction = module.reduction
+ ignore_index = module.ignore_index
+
+ return self.block_builder.emit(
+ relax.op.nn.nll_loss(
+ relax.op.nn.log_softmax(preds), targets, weights, reduction,
ignore_index
+ )
+ )
+
def _einsum(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
@@ -740,6 +811,80 @@ class TorchFXImporter:
)
)
+ def _interpolate(self, node: fx.Node) -> relax.Var:
+ # torch.nn.functional.interpolate(
+ # input, size=None, scale_factor=None, mode='nearest',
align_corners=None,
+ # recompute_scale_factor=None, antialias=False)
+ # (TODO) this is a temporary implementation for interpolate that only
considers NCHW layout
+ # it basically replicates the implementation in
tvm.relay.frontend.pytorch
+ data = self.env[node.args[0]]
+ size = (
+ node.args[1]
+ if len(node.args) > 1
+ else (node.kwargs["size"] if "size" in node.kwargs else None)
+ )
+ scale_factor = (
+ node.args[2]
+ if len(node.args) > 2
+ else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs
else None)
+ )
+ method = (
+ node.args[3]
+ if len(node.args) > 3
+ else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
+ )
+ align_corners = (
+ node.args[4]
+ if len(node.args) > 4
+ else (node.kwargs["align_corners"] if "align_corners" in
node.kwargs else None)
+ )
+ recompute_scale_factor = (
+ node.args[5]
+ if len(node.args) > 5
+ else (
+ node.kwargs["recompute_scale_factor"]
+ if "recompute_scale_factor" in node.kwargs
+ else None
+ )
+ )
+ antialias = (
+ node.args[6]
+ if len(node.args) > 6
+ else (node.kwargs["antialias"] if "antialias" in node.kwargs else
False)
+ )
+
+ assert recompute_scale_factor is None
+ assert antialias is False
+
+ if size is None:
+ shape = self.shape_of(data)
+ assert isinstance(shape, relax.ShapeExpr)
+ if isinstance(scale_factor, tuple):
+ assert len(scale_factor) == len(shape) - 2
+ size = tuple(
+ int(shape[i].value * scale_factor[i - 2]) for i in
range(2, len(shape))
+ )
+ else:
+ size = tuple(int(shape[i].value * scale_factor) for i in
range(2, len(shape)))
+
+ if method.startswith("nearest"):
+ method = "nearest_neighbor"
+ elif method[0:2] == "bi":
+ method = method[2:]
+
+ if method == "nearest_neighbor":
+ coord_trans = "asymmetric"
+ elif align_corners:
+ coord_trans = "align_corners"
+ else:
+ coord_trans = "half_pixel"
+
+ return self.block_builder.emit(
+ relax.op.image.resize2d(
+ data, size, layout="NCHW", method=method,
coordinate_transformation_mode=coord_trans
+ )
+ )
+
def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) ->
relax.Var:
from torch.fx.immutable_collections import immutable_list
import numpy as np # type: ignore
@@ -913,230 +1058,106 @@ class TorchFXImporter:
return convert
- ########## DataType ##########
-
- def _float(self, node: fx.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float32"))
-
- def _half(self, node: fx.Node) -> relax.Var:
- return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float16"))
+ ########## Manipulation ##########
- def _to(self, node: fx.Node) -> relax.Var:
- import torch
+ def _cat(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+ return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+ def _chunk(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
- if len(node.args) == 2:
- if isinstance(node.args[1], torch.dtype):
- dtype = TorchFXImporter._convert_data_type(node.args[1],
self.env)
- return self.block_builder.emit(relax.op.astype(x, dtype))
- elif "dtype" in node.kwargs:
- dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"],
self.env)
- return self.block_builder.emit(relax.op.astype(x, dtype))
- return x
+ chunks = node.args[1]
+ dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+ return self.block_builder.emit(relax.op.split(x, chunks, dim))
- def _type(self, node: fx.Node) -> relax.Var:
+ def _cumsum(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
- dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
- return self.block_builder.emit(relax.op.astype(x, dtype))
- ########## Creation ##########
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
None)
+ if "dtype" in node.kwargs:
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]),
self.env)
+ else:
+ dtype = None
+ if "out" in node.kwargs:
+ raise ValueError("specifying out for cumsum is not supported yet")
- def _arange(self, node: fx.Node) -> relax.Var:
- import torch
+ return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
- start_end_step = [None, None, None]
- if "start" in node.kwargs:
- start_end_step[0] = node.kwargs["start"]
- if "end" in node.kwargs:
- start_end_step[1] = node.kwargs["end"]
- if "step" in node.kwargs:
- start_end_step[2] = node.kwargs["step"]
+ def _expand(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ sizes = args[1:] if len(args) > 2 else args[1]
+ broadcast_shape, in_shape = [], self.shape_of(args[0])
+ for idx, i in enumerate(sizes):
+ if isinstance(i, int) and i == -1:
+ broadcast_shape.append(in_shape[idx])
+ else:
+ broadcast_shape.append(i)
+ return self.block_builder.emit(relax.op.broadcast_to(args[0],
broadcast_shape))
- if len(node.args) == 1:
- assert start_end_step[1] is None
- start_end_step[1] = node.args[0]
- elif len(node.args) == 2:
- assert start_end_step[0] is None
- assert start_end_step[1] is None
- start_end_step[0] = node.args[0]
- start_end_step[1] = node.args[1]
- elif len(node.args) == 3:
- assert start_end_step[0] is None
- assert start_end_step[1] is None
- assert start_end_step[2] is None
- start_end_step[0] = node.args[0]
- start_end_step[1] = node.args[1]
- start_end_step[2] = node.args[2]
+ def _flatten_impl(self, x, start_dim, end_dim) -> relax.Var:
+ shape = self.shape_of(x)
+ start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+ end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+ flattened = reduce(lambda x, y: x * y, [shape[i] for i in
range(start_dim, end_dim + 1)])
+ new_shape = (
+ [shape[i] for i in range(0, start_dim)]
+ + [flattened]
+ + [shape[i] for i in range(end_dim + 1, len(shape))]
+ )
+ return self.block_builder.emit(relax.op.reshape(x, new_shape))
- if start_end_step[0] is None:
- start_end_step[0] = 0
- if start_end_step[2] is None:
- start_end_step[2] = 1
+ def _flatten(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ start_dim = node.args[1] if len(node.args) >= 2 else
node.kwargs.get("start_dim", 0)
+ end_dim = node.args[2] if len(node.args) == 3 else
node.kwargs.get("end_dim", -1)
+ return self._flatten_impl(x, start_dim, end_dim)
- if "dtype" in node.kwargs:
- dtype =
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
- elif any([isinstance(x, float) for x in start_end_step]):
- dtype =
TorchFXImporter._convert_data_type(torch.get_default_dtype())
- else:
- dtype = "int64"
- start_end_step = [
- self.env[x] if isinstance(x, torch.fx.Node) else x for x in
start_end_step
- ]
- return self.block_builder.emit(relax.op.arange(*start_end_step,
dtype=dtype))
+ def _flatten_module(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ start_dim = module.start_dim
+ end_dim = module.end_dim
+ return self._flatten_impl(x, start_dim, end_dim)
- def _empty(self, node: fx.Node) -> relax.Var:
- dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]),
self.env)
- return self.block_builder.emit(relax.op.zeros(node.args, dtype))
+ def _permute(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
- def _inplace_fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
- dtype = x.struct_info.dtype
- value = args[1] if isinstance(args[1], relax.Expr) else
relax.const(args[1], dtype)
- filled = self.block_builder.emit(relax.op.full(x.struct_info.shape,
value, dtype))
- self.env[node.args[0]] = filled
- return filled
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
+ return self.block_builder.emit(relax.op.permute_dims(x, dims))
- def _tensor(self, node: fx.Node) -> relax.Var:
- dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None
- if isinstance(node.args[0], float):
- return relax.const(node.args[0], dtype if dtype is not None else
"float32")
- elif isinstance(node.args[0], int):
- return relax.const(node.args[0], dtype if dtype is not None else
"int64")
- raise ValueError("torch.tensor with value not a float or int is not
accepted")
+ def _repeat(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
- def _inplace_tril_triu(self, op: Callable) -> Callable:
- from torch import fx
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
+ return self.block_builder.emit(relax.op.tile(x, dims))
- def convert(node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- k = node.args[1] if len(node.args) > 1 else 0
- assert isinstance(k, int)
-
- mutated = self.block_builder.emit(op(x, k))
- self.env[node.args[0]] = mutated
- return mutated
-
- return convert
-
- def _new_ones(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- self_var = args[0]
- size = args[1:]
- if not isinstance(size, (list, tuple)):
- size = (size,)
- size = relax.ShapeExpr(size)
- return self.block_builder.emit(
- relax.op.full(
- size,
- relax.const(1, self_var.struct_info.dtype),
- self_var.struct_info.dtype,
- )
- )
-
- def _ones(self, node: fx.Node) -> relax.Var:
- import torch
-
- args = self.retrieve_args(node)
- size = args[0]
- if not isinstance(size, (list, tuple)):
- size = (size,)
- size = relax.ShapeExpr(size)
- dtype = (
- TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]),
self.env)
- if "dtype" in node.kwargs
- else TorchFXImporter._convert_data_type(torch.get_default_dtype(),
self.env)
- )
- return self.block_builder.emit(
- relax.op.full(
- size,
- relax.const(1, dtype),
- dtype,
- )
- )
-
- def _full(self, node: fx.Node) -> relax.Var:
- import torch
-
- args = self.retrieve_args(node)
- size = args[0]
- if not isinstance(size, (list, tuple)):
- size = (size,)
- size = relax.ShapeExpr(size)
- dtype = (
- TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]),
self.env)
- if "dtype" in node.kwargs
- else TorchFXImporter._convert_data_type(torch.get_default_dtype(),
self.env)
- )
- value = args[1] if isinstance(args[1], relax.expr.Constant) else
relax.const(args[1], dtype)
- return self.block_builder.emit(
- relax.op.full(
- size,
- value,
- dtype,
- )
- )
-
- ########## Manipulation ##########
-
- def _cat(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
- return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+ def _reshape(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
- def _expand(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
- broadcast_shape, in_shape = [], self.shape_of(args[0])
- for idx, i in enumerate(args[1:]):
- if isinstance(i, int) and i == -1:
- broadcast_shape.append(in_shape[idx])
- else:
- broadcast_shape.append(i)
- return self.block_builder.emit(relax.op.broadcast_to(args[0],
broadcast_shape))
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
+ return self.block_builder.emit(relax.op.reshape(x, dims))
- def _flatten(self, node: fx.Node) -> relax.Var:
+ def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
- if node.target in self.named_modules:
- module = self.named_modules[node.target]
- start_dim = module.start_dim
- end_dim = module.end_dim
- else:
- start_dim = node.args[1] if len(node.args) >= 2 else 0
- end_dim = node.args[2] if len(node.args) == 3 else -1
shape = self.shape_of(x)
- start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
- end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
- flattened = reduce(lambda x, y: x * y, [shape[i] for i in
range(start_dim, end_dim + 1)])
- new_shape = (
- [shape[i] for i in range(0, start_dim)]
- + [flattened]
- + [shape[i] for i in range(end_dim + 1, len(shape))]
- )
- return self.block_builder.emit(relax.op.reshape(x, new_shape))
-
- def _permute(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- if isinstance(args[1], (torch.Size, tuple, list)):
- return self.block_builder.emit(relax.op.permute_dims(args[0],
tuple(args[1])))
- return self.block_builder.emit(relax.op.permute_dims(args[0],
args[1:]))
-
- def _reshape(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- if isinstance(args[1], (torch.Size, tuple, list)):
- return self.block_builder.emit(relax.op.reshape(args[0],
tuple(args[1])))
- return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
+ if len(node.args) == 1:
+ assert isinstance(shape, relax.ShapeExpr)
+ return shape
+ assert len(node.args) == 2
+ idx = node.args[1]
+ return self.shape_of(x)[idx].value
def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
split_size = node.args[1]
- if "dim" in node.kwargs:
- dim = node.kwargs["dim"]
- else:
- dim = 0
+ dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
if isinstance(split_size, (list, tuple)):
n_section = []
for s in split_size[:-1]:
@@ -1146,17 +1167,18 @@ class TorchFXImporter:
n_section = (self.shape_of(x)[dim].value + split_size - 1) //
split_size
return self.block_builder.emit(relax.op.split(x, n_section, dim))
- def _chunk(self, node: fx.Node) -> relax.Var:
+ def _squeeze(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
- chunks = node.args[1]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
None)
+ return self.block_builder.emit(relax.op.squeeze(x, dim))
- if "dim" in node.kwargs:
- dim = node.kwargs["dim"]
- elif len(node.args) > 2:
- dim = node.args[2]
- else:
- dim = 0
- return self.block_builder.emit(relax.op.split(x, chunks, dim))
+ def _tile(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else
args[1:]
+ return self.block_builder.emit(relax.op.tile(x, dims))
def _transpose(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
@@ -1164,50 +1186,80 @@ class TorchFXImporter:
full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]],
full_idx[args[1]]
return self.block_builder.emit(relax.op.permute_dims(args[0],
full_idx))
- def _squeeze(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
-
- if "dim" in node.kwargs:
- dim = node.kwargs["dim"]
- elif len(node.args) > 1:
- dim = node.args[1]
- else:
- dim = None
- return self.block_builder.emit(relax.op.squeeze(x, dim))
+ ########## Creation ##########
- def _repeat(self, node: fx.Node) -> relax.Var:
+ def _arange(self, node: fx.Node) -> relax.Var:
import torch # type: ignore
- args = self.retrieve_args(node)
- if isinstance(args[1], (torch.Size, tuple, list)):
- return self.block_builder.emit(relax.op.tile(args[0],
tuple(args[1])))
- return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
-
- def _tile(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
+ start_end_step = [None, None, None]
+ if "start" in node.kwargs:
+ start_end_step[0] = node.kwargs["start"]
+ if "end" in node.kwargs:
+ start_end_step[1] = node.kwargs["end"]
+ if "step" in node.kwargs:
+ start_end_step[2] = node.kwargs["step"]
- args = self.retrieve_args(node)
- if isinstance(args[1], (torch.Size, tuple, list)):
- return self.block_builder.emit(relax.op.tile(args[0],
tuple(args[1])))
- return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+ if len(node.args) == 1:
+ assert start_end_step[1] is None
+ start_end_step[1] = node.args[0]
+ elif len(node.args) == 2:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ elif len(node.args) == 3:
+ assert start_end_step[0] is None
+ assert start_end_step[1] is None
+ assert start_end_step[2] is None
+ start_end_step[0] = node.args[0]
+ start_end_step[1] = node.args[1]
+ start_end_step[2] = node.args[2]
- def _cumsum(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
+ if start_end_step[0] is None:
+ start_end_step[0] = 0
+ if start_end_step[2] is None:
+ start_end_step[2] = 1
- if "dim" in node.kwargs:
- dim = node.kwargs["dim"]
- elif len(node.args) > 1:
- dim = node.args[1]
- else:
- dim = None
if "dtype" in node.kwargs:
- dtype =
TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]),
self.env)
+ elif any([isinstance(x, float) for x in start_end_step]):
+ dtype = self._convert_data_type(torch.get_default_dtype())
else:
- dtype = None
- if "out" in node.kwargs:
- raise ValueError("specifying out for cumsum is not supported yet")
+ dtype = "int64"
+ start_end_step = [
+ self.env[x] if isinstance(x, torch.fx.Node) else x for x in
start_end_step
+ ]
+ return self.block_builder.emit(relax.op.arange(*start_end_step,
dtype=dtype))
- return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+ def _empty(self, node: fx.Node) -> relax.Var:
+ dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+ return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+ def _inplace_fill(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ dtype = x.struct_info.dtype
+ value = args[1] if isinstance(args[1], relax.Expr) else
relax.const(args[1], dtype)
+ filled = self.block_builder.emit(relax.op.full(x.struct_info.shape,
value, dtype))
+ self.env[node.args[0]] = filled
+ return filled
+
+ def _full(self, node: fx.Node) -> relax.Var:
+ import torch
+
+ args = self.retrieve_args(node)
+ size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple))
else (args[0],))
+ dtype = self._convert_data_type(
+ node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+ )
+ value = args[1] if isinstance(args[1], relax.expr.Constant) else
relax.const(args[1], dtype)
+ return self.block_builder.emit(
+ relax.op.full(
+ size,
+ value,
+ dtype,
+ )
+ )
def _index_select(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
@@ -1215,14 +1267,6 @@ class TorchFXImporter:
index = self.env[node.args[2]]
return self.block_builder.emit(relax.op.take(x, index, dim))
- def _masked_fill(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- mask = self.env[node.args[1]]
- value = node.args[2]
- rx_value = relax.const(value)
- values = self.block_builder.emit(relax.op.full_like(x, rx_value))
- return self.block_builder.emit(relax.op.where(mask, values, x))
-
def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
mask = self.env[node.args[1]]
@@ -1233,168 +1277,79 @@ class TorchFXImporter:
self.env[node.args[0]] = output
return output
- ########## Neural Network ##########
-
- def _softmax(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- if node.target in self.named_modules:
- module = self.named_modules[node.target]
- dim = module.dim
- else:
- nargs = len(node.args)
- dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
- assert dim is not None
- return self.block_builder.emit(relax.op.nn.softmax(x, dim))
-
- def _batch_norm_2d(self, node: fx.Node) -> relax.Var:
+ def _masked_fill(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
- module = self.named_modules[node.target]
- weight = self.params[module.weight]
- bias = self.params[module.bias]
- running_mean = self._convert_torch_tensor_to_relax(module.running_mean)
- running_var = self._convert_torch_tensor_to_relax(module.running_var)
- eps = module.eps
+ mask = self.env[node.args[1]]
+ rx_value = relax.const(node.args[2])
+ values = self.block_builder.emit(relax.op.full_like(x, rx_value))
+ return self.block_builder.emit(relax.op.where(mask, values, x))
- res_tuple = self.block_builder.emit(
- relax.op.nn.batch_norm(
- x,
- weight,
- bias,
- running_mean,
- running_var,
- axis=1,
- epsilon=eps,
+ def _new_ones(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ self_var = args[0]
+ size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+ if not isinstance(size, (list, tuple)):
+ size = (size,)
+ size = relax.ShapeExpr(size)
+ return self.block_builder.emit(
+ relax.op.full(
+ size,
+ relax.const(1, self_var.struct_info.dtype),
+ self_var.struct_info.dtype,
)
)
- return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+ def _ones(self, node: fx.Node) -> relax.Var:
+ import torch
- def _interpolate(self, node: fx.Node) -> relax.Var:
- # torch.nn.functional.interpolate(
- # input, size=None, scale_factor=None, mode='nearest',
align_corners=None,
- # recompute_scale_factor=None, antialias=False)
- # (TODO) this is a temporary implementation for interpolate that only
considers NCHW layout
- # it basically replicates the implementation in
tvm.relay.frontend.pytorch
- data = self.env[node.args[0]]
- size = (
- node.args[1]
- if len(node.args) > 1
- else (node.kwargs["size"] if "size" in node.kwargs else None)
- )
- scale_factor = (
- node.args[2]
- if len(node.args) > 2
- else (node.kwargs["scale_factor"] if "scale_factor" in node.kwargs
else None)
- )
- method = (
- node.args[3]
- if len(node.args) > 3
- else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
- )
- align_corners = (
- node.args[4]
- if len(node.args) > 4
- else (node.kwargs["align_corners"] if "align_corners" in
node.kwargs else None)
- )
- recompute_scale_factor = (
- node.args[5]
- if len(node.args) > 5
- else (
- node.kwargs["recompute_scale_factor"]
- if "recompute_scale_factor" in node.kwargs
- else None
- )
- )
- antialias = (
- node.args[6]
- if len(node.args) > 6
- else (node.kwargs["antialias"] if "antialias" in node.kwargs else
False)
+ args = self.retrieve_args(node)
+ size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple))
else (args[0],))
+ dtype = self._convert_data_type(
+ node.kwargs.get("dtype", torch.get_default_dtype()), self.env
)
-
- assert recompute_scale_factor is None
- assert antialias is False
-
- if size is None:
- shape = self.shape_of(data)
- assert isinstance(shape, relax.ShapeExpr)
- if isinstance(scale_factor, tuple):
- assert len(scale_factor) == len(shape) - 2
- size = tuple(
- int(shape[i].value * scale_factor[i - 2]) for i in
range(2, len(shape))
- )
- else:
- size = tuple(int(shape[i].value * scale_factor) for i in
range(2, len(shape)))
-
- if method.startswith("nearest"):
- method = "nearest_neighbor"
- elif method[0:2] == "bi":
- method = method[2:]
-
- if method == "nearest_neighbor":
- coord_trans = "asymmetric"
- elif align_corners:
- coord_trans = "align_corners"
- else:
- coord_trans = "half_pixel"
-
return self.block_builder.emit(
- relax.op.image.resize2d(
- data, size, layout="NCHW", method=method,
coordinate_transformation_mode=coord_trans
+ relax.op.full(
+ size,
+ relax.const(1, dtype),
+ dtype,
)
)
- def _cross_entropy(self, node: fx.Node) -> relax.Expr:
- preds = self.env[node.args[0]]
- targets = self.env[node.args[1]]
-
- # functional.cross_entropy
- if node.target not in self.named_modules:
- weights = node.kwargs["weight"]
- if weights is not None:
- weights = self.env[weights]
- reduction = node.kwargs["reduction"]
- ignore_index = node.kwargs["ignore_index"]
-
- return self.block_builder.emit(
- relax.op.nn.nll_loss(
- relax.op.nn.log_softmax(preds), targets, weights,
reduction, ignore_index
- )
- )
+ def _tensor(self, node: fx.Node) -> relax.Var:
+ dtype = node.kwargs.get("dtype", None)
+ if isinstance(node.args[0], float):
+ return relax.const(node.args[0], dtype if dtype is not None else
"float32")
+ elif isinstance(node.args[0], int):
+ return relax.const(node.args[0], dtype if dtype is not None else
"int64")
+ raise ValueError("torch.tensor with value not a float or int is not
accepted")
- module = self.named_modules[node.target]
+ ########## DataType ##########
- weights = module.weight
- if weights is not None:
- if weights in self.params:
- weights = self.params[weights]
- else:
- weights = relax.const(weights.numpy(), preds.struct_info.dtype)
- reduction = module.reduction
- ignore_index = module.ignore_index
+ def _float(self, node: fx.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float32"))
- return self.block_builder.emit(
- relax.op.nn.nll_loss(
- relax.op.nn.log_softmax(preds), targets, weights, reduction,
ignore_index
- )
- )
+ def _half(self, node: fx.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float16"))
- ########## Others ##########
+ def _to(self, node: fx.Node) -> relax.Var:
+ import torch
- def _sym_size_int(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
- shape = self.shape_of(x)
- idx = node.args[1]
- return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+ if len(node.args) == 2:
+ if isinstance(node.args[1], torch.dtype):
+ dtype = TorchFXImporter._convert_data_type(node.args[1],
self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ elif "dtype" in node.kwargs:
+ dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"],
self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+ return x
- def _size(self, node: fx.Node) -> relax.Expr:
+ def _type(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
- shape = self.shape_of(x)
- if len(node.args) == 1:
- assert isinstance(shape, relax.ShapeExpr)
- return shape
- assert len(node.args) == 2
- idx = node.args[1]
- return self.shape_of(x)[idx].value
+ dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
+ return self.block_builder.emit(relax.op.astype(x, dtype))
+
+ ########## Others ##########
def _getattr(self, node: fx.Node) -> relax.Var:
if isinstance(self.env[node.args[0]], relax.Expr):
@@ -1485,6 +1440,12 @@ class TorchFXImporter:
else:
assert False
+ def _sym_size_int(self, node: fx.Node) -> relax.Expr:
+ x = self.env[node.args[0]]
+ shape = self.shape_of(x)
+ idx = node.args[1]
+ return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
+
def create_convert_map(self):
import operator
from torch import nn
@@ -1511,20 +1472,20 @@ class TorchFXImporter:
# neural network
nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
nn.AvgPool2d: self._avg_pool2d_module,
- nn.BatchNorm2d: self._batch_norm_2d,
+ nn.BatchNorm2d: self._batch_norm_2d_module,
nn.Conv1d: self._conv1d_module,
nn.Conv2d: self._conv2d_module,
nn.Conv3d: self._conv3d_module,
nn.ConvTranspose1d: self._conv1d_transpose_module,
nn.ConvTranspose2d: self._conv2d_transpose_module,
- nn.CrossEntropyLoss: self._cross_entropy,
+ nn.CrossEntropyLoss: self._cross_entropy_module,
nn.GroupNorm: self._group_norm_module,
nn.LayerNorm: self._layer_norm_module,
nn.Linear: self._linear_module,
nn.MaxPool2d: self._max_pool2d_module,
nn.modules.sparse.Embedding: self._embedding_module,
# tensor manipulation
- nn.Flatten: self._flatten,
+ nn.Flatten: self._flatten_module,
## call_function and call_method
# unary
"acos": self._unary_op(relax.op.acos),
@@ -1603,6 +1564,7 @@ class TorchFXImporter:
"argmin": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat": self._cat,
+ "chunk": self._chunk,
"concat": self._cat,
"contiguous": lambda node: self.env[node.args[0]],
"cumsum": self._cumsum,
@@ -1622,7 +1584,6 @@ class TorchFXImporter:
"view": self._reshape,
# tensor creation
"arange": self._arange,
- "chunk": self._chunk,
"empty": self._empty,
"fill_": self._inplace_fill,
"full": self._full,
@@ -1632,11 +1593,11 @@ class TorchFXImporter:
"new_ones": self._new_ones,
"ones": self._ones,
"tensor": self._tensor,
- "to": self._to,
# datatype
"astype": self._type,
"float": self._float,
"half": self._half,
+ "to": self._to,
"type": self._type,
# other
"getattr": self._getattr,