This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d9ee6377cd [Relax][PyTorch] Support neural network ops for
ExportedProgram importer (#17426)
d9ee6377cd is described below
commit d9ee6377cdd8395b27385d2fc2745b741fad6183
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Sep 29 06:59:33 2024 +0900
[Relax][PyTorch] Support neural network ops for ExportedProgram importer
(#17426)
* support batchnorm2d and getitem
* support addmm
* support avg_pool2d
* support baddbmm
* support bmm
* support conv_transpose1d
* support conv_transpose2d
* support conv1d
* support conv3d
* support einsum
* support embedding
* support group_norm
* support layer_norm
* support scaled_dot_product_attention
* support unbind
* support interpolate
* fix lint error
---
.../frontend/torch/base_fx_graph_translator.py | 464 ++++++++
.../frontend/torch/exported_program_translator.py | 111 ++
python/tvm/relax/frontend/torch/fx_translator.py | 482 +-------
.../relax/test_frontend_from_exported_program.py | 1150 +++++++++++++++++++-
4 files changed, 1723 insertions(+), 484 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a41b9b6d4f..52784dc8c3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -227,6 +227,228 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
)
+ def _addmm(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ y = self.env[node.args[1]]
+ z = self.env[node.args[2]]
+ alpha = node.kwargs.get("alpha", 1)
+ beta = node.kwargs.get("beta", 1)
+
+ res = None
+ if alpha != 0:
+ res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32"))
+ if alpha != 1:
+ dtype = res.struct_info.dtype
+ res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
+ if beta != 0:
+ dtype = x.struct_info.dtype
+ if beta != 1:
+ bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
+ else:
+ bias = x
+ res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
+ return res
+
+ def _avg_pool2d_impl(
+ self,
+ x: relax.Expr,
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+ stride: Optional[Union[int, Tuple[int, int]]] = None,
+ padding: Optional[int] = 0,
+ ceil_mode: Optional[bool] = False,
+ ) -> relax.Var:
+ stride = kernel_size if stride is None or stride == [] else stride
+ return self.block_builder.emit(
+ relax.op.nn.avg_pool2d(
+ x,
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ ceil_mode=ceil_mode,
+ layout="NCHW",
+ )
+ )
+
+ def _avg_pool2d(self, node: fx.Node) -> relax.Var:
+ args, kwargs = node.normalized_arguments(node)
+ x = self.env[args[0]]
+ kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
+ stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
+ padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
+ ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
+ return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
+
+ def _baddbmm(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ batch1 = self.env[node.args[1]]
+ batch2 = self.env[node.args[2]]
+ alpha = node.kwargs.get("alpha", 1)
+ beta = node.kwargs.get("beta", 1)
+
+ res = None
+ if alpha != 0:
+ res = self.block_builder.emit(relax.op.matmul(batch1, batch2))
+ if alpha != 1:
+ dtype = res.struct_info.dtype
+ res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
+ if beta != 0:
+ dtype = x.struct_info.dtype
+ if beta != 1:
+ bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
+ else:
+ bias = x
+ res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
+ return res
+
+ def _conv_transpose1d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ) -> relax.Var:
+ conv1d_transpose = self.block_builder.emit(
+ relax.op.nn.conv1d_transpose(
+ x,
+ weight,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_dtype="float32",
+ )
+ )
+
+ if bias is None:
+ return conv1d_transpose
+
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1))
+ return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
+
+ def _conv_transpose1d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv_transpose1d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+ def _conv_transpose2d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ) -> relax.Var:
+ conv2d_transpose = self.block_builder.emit(
+ relax.op.nn.conv2d_transpose(
+ x,
+ weight,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_dtype="float32",
+ )
+ )
+
+ if bias is None:
+ return conv2d_transpose
+
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1, 1))
+ return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
+
+ def _conv_transpose2d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv_transpose2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+ def _conv1d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ) -> relax.Var:
+ conv1d = self.block_builder.emit(
+ relax.op.nn.conv1d(
+ x,
+ weight,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_dtype="float32",
+ )
+ )
+
+ if bias is None:
+ return conv1d
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1))
+ return self.block_builder.emit(relax.op.add(conv1d, bias))
+
+ def _conv1d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv1d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
def _conv2d_impl(
self,
x: relax.Expr,
@@ -276,6 +498,134 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
groups=groups,
)
+ def _conv3d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ):
+ conv3d = self.block_builder.emit(
+ relax.op.nn.conv3d(
+ x,
+ weight,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_dtype="float32",
+ )
+ )
+
+ if bias is None:
+ return conv3d
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+ return self.block_builder.emit(relax.op.add(conv3d, bias))
+
+ def _conv3d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv3d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+ def _einsum(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+ return self.block_builder.emit(relax.op.einsum(operands, args[0]))
+
+ def _embedding_impl(
+ self,
+ x,
+ weight,
+ ) -> relax.Var:
+ x = self.block_builder.emit(relax.op.astype(x, "int32"))
+
+ ndim = x.struct_info.ndim
+ if ndim == 1:
+ return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+ else:
+ x_shape = x.struct_info.shape.values
+ emb_size = weight.struct_info.shape.values[-1]
+ x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+ embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+ return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
+
+ def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
+ from torch.fx.immutable_collections import immutable_list
+ import numpy as np # type: ignore
+
+ if isinstance(normalized_shape, (immutable_list, tuple)):
+ normalized_shape = tuple(normalized_shape)
+ else:
+ try:
+ normalized_shape = self.env[normalized_shape]
+ except TypeError:
+ normalized_shape = tuple(normalized_shape)
+
+ dim_num = len(normalized_shape)
+ axes = list(range(-dim_num, 0))
+
+ if gamma is None:
+ shape_tuple = [int(s) for s in normalized_shape]
+ gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype)
+ if beta is None:
+ shape_tuple = [int(s) for s in normalized_shape]
+ beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
+
+ return self.block_builder.emit(
+ relax.op.nn.layer_norm(
+ x,
+ gamma,
+ beta,
+ axes=axes,
+ epsilon=eps,
+ )
+ )
+
+ def _layer_norm(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ normalized_shape = node.args[1]
+ gamma = self.env[node.args[2]] if len(node.args) > 2 else None
+ beta = self.env[node.args[3]] if len(node.args) > 3 else None
+ eps = node.args[4] if len(node.args) > 4 else 1e-05
+ return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
+ def _layer_norm_module(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ normalized_shape = module.normalized_shape
+ if module.elementwise_affine:
+ gamma = self.params[module.weight]
+ beta = self.params[module.bias]
+ else:
+ gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype)
+ beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype)
+ eps = module.eps
+ return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
def _linear(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
@@ -316,6 +666,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
+ transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+ query = transpose_S_H(self.env[node.args[0]])
+ key = transpose_S_H(self.env[node.args[1]])
+ value = transpose_S_H(self.env[node.args[2]])
+ attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
+ dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
+ assert dropout_p == 0.0, "Dropout is not supported"
+ is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
+ causal_mask = "TopLeft" if is_causal else None
+
+ if attn_mask is not None:
+ attn_mask = self.env[attn_mask]
+ msg = "Only a float mask is supported for the attn_mask input."
+ assert "float" in attn_mask.struct_info.dtype, msg
+
+ return self.block_builder.emit(
+ transpose_S_H(
+ relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
+ )
+ )
+
+ def _unbind(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+ assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
+ selections = self.shape_of(x)[dim].value
+ n_section = list(range(1, selections + 1))
+ ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
+ for i in range(selections):
+ ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
+ return self.block_builder.emit(relax.Tuple(ret))
+
########## Statistical ##########
def _mean(self, node: fx.Node) -> relax.Var:
@@ -357,6 +740,87 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
########## Others ##########
+ def _getitem(self, node: fx.Node) -> relax.Var:
+ import torch
+
+ x = self.env[node.args[0]]
+ if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
+ return x[node.args[1]]
+ elif isinstance(x, relax.Var):
+ if isinstance(x.struct_info, relax.TupleStructInfo):
+ return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
+
+ assert isinstance(x.struct_info, relax.TensorStructInfo)
+ take_indices = []
+ take_axes = []
+ stride_begin = []
+ stride_end = []
+ stride = []
+ stride_axes = []
+ expand_dim = []
+ i = 0
+ shape = self.shape_of(x)
+ non_ellipsis_cnt = 0
+ for index in node.args[1]:
+ if isinstance(index, (int, slice, torch.fx.Node)):
+ non_ellipsis_cnt += 1
+ for index in node.args[1]:
+ if isinstance(index, int):
+ stride_begin.append(index)
+ stride_end.append(index + 1)
+ stride.append(1)
+ stride_axes.append(i)
+ i = i + 1
+ elif isinstance(index, slice):
+ stride_begin.append(0 if index.start is None else index.start)
+ stride_end.append(shape[i] if index.stop is None else index.stop)
+ stride.append(1 if index.step is None else index.step)
+ stride_axes.append(i)
+ i = i + 1
+ elif index is None:
+ expand_dim.append(len(stride_axes) + len(expand_dim))
+ elif index is Ellipsis:
+ for _ in range(len(shape) - non_ellipsis_cnt):
+ stride_begin.append(0)
+ stride_end.append(shape[i])
+ stride.append(1)
+ stride_axes.append(i)
+ i += 1
+ elif isinstance(index, torch.fx.Node):
+ node_index = self.env[index]
+ if not isinstance(node_index, relax.Expr):
+ raise ValueError(
+ "Unsupported index type for relax.op.take: " + str(type(node_index))
+ )
+ take_indices.append(node_index)
+ take_axes.append(i)
+ i = i + 1
+ else:
+ raise ValueError("Unsupported index type: " + str(type(index)))
+ while i < len(shape):
+ stride_begin.append(0)
+ stride_end.append(shape[i])
+ stride.append(1)
+ stride_axes.append(i)
+ i += 1
+ taken = x
+ if len(take_indices) > 1:
+ raise ValueError("Multiple tensors as index not yet supported")
+ for each_index, each_axis in zip(take_indices, take_axes):
+ taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis))
+ sliced = self.block_builder.emit(
+ relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride)
+ )
+ sliced_shape = list(self.shape_of(sliced))
+ for i in expand_dim:
+ sliced_shape.insert(i, 1)
+ return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+ elif isinstance(x, relax.Constant):
+ dtype = x.struct_info.dtype
+ return relax.const(x.data.numpy()[node.args[1]], dtype)
+ else:
+ assert False
+
@abc.abstractmethod
def create_convert_map(
self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 11594690cd..64583d7509 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -74,6 +74,94 @@ class ExportedProgramImporter(BaseFXGraphImporter):
max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0)
return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+ ########## Neural Network ##########
+
+ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+ import numpy as np
+
+ x = self.env[node.args[0]]
+ channel = int(self.shape_of(x)[1])
+ dtype = x.struct_info.dtype
+ weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype))
+ bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
+ running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype))
+ running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype))
+ momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+ eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
+
+ return self.block_builder.emit(
+ relax.op.nn.batch_norm(
+ x,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ axis=1,
+ epsilon=eps,
+ momentum=momentum,
+ )
+ )
+
+ def _group_norm(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ num_groups = node.args[1]
+ gamma = self.env[node.args[2]] if len(node.args) > 2 else None
+ beta = self.env[node.args[3]] if len(node.args) > 3 else None
+ eps = node.args[4] if len(node.args) > 4 else 1e-05
+
+ dim = len(self.shape_of(x))
+ return self.block_builder.emit(
+ relax.op.nn.group_norm(
+ x,
+ gamma,
+ beta,
+ num_groups=num_groups,
+ channel_axis=1,
+ axes=list(range(2, dim)),
+ epsilon=eps,
+ )
+ )
+
+ def _upsample_impl(
+ self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str
+ ) -> relax.Var:
+ coord_trans = "align_corners" if align_corners else "half_pixel"
+
+ if size is None:
+ shape = self.shape_of(x)
+ assert isinstance(shape, relax.ShapeExpr)
+ if isinstance(scale_factor, (tuple, list)):
+ assert len(scale_factor) == len(shape) - 2
+ size = tuple(
+ int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
+ )
+ else:
+ size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
+
+ return self.block_builder.emit(
+ relax.op.image.resize2d(
+ x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+ )
+ )
+
+ def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
+ align_corners = (
+ node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
+ )
+ scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
+ return self._upsample_impl(x, size, align_corners, scale_factor, "linear")
+
+ def _upsample_nearest2d(self, node: fx.node) -> relax.Var:
+ x = self.env[node.args[0]]
+ size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
+ align_corners = (
+ node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
+ )
+ scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
+ return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
+
def create_convert_map(
self,
) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -129,10 +217,31 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
"sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
# neural network
+ "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+ "addmm.default": self._addmm,
+ "avg_pool2d.default": self._avg_pool2d,
+ "baddbmm.default": self._baddbmm,
+ "bmm.default": self._binary_op(
+ partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+ ),
+ "conv_transpose1d.default": self._conv_transpose1d,
+ "conv_transpose2d.input": self._conv_transpose2d,
+ "conv1d.default": self._conv1d,
"conv2d.default": self._conv2d,
+ "conv3d.default": self._conv3d,
+ "einsum.default": self._einsum,
+ "embedding.default": lambda node: self._embedding_impl(
+ self.env[node.args[1]], self.env[node.args[0]]
+ ),
+ "group_norm.default": self._group_norm,
+ "layer_norm.default": self._layer_norm,
"linear.default": self._linear,
"max_pool2d.default": self._max_pool2d,
+ "scaled_dot_product_attention.default": self._scaled_dot_product_attention,
+ "unbind.int": self._unbind,
+ "upsample_bilinear2d.vec": self._upsample_bilinear2d,
+ "upsample_nearest2d.vec": self._upsample_nearest2d,
# statistical
"mean.dim": self._mean,
"sum.dim_IntList": self._sum,
@@ -141,6 +250,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"view.default": self._reshape,
+ # other
+ "getitem": self._getitem,
}
def from_exported_program(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dc6ebc2eb3..c60c7c3953 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
# pylint: disable=import-outside-toplevel
"""PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
from functools import partial, reduce
import tvm
@@ -107,57 +107,6 @@ class TorchFXImporter(BaseFXGraphImporter):
relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
)
- def _addmm(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- y = self.env[node.args[1]]
- z = self.env[node.args[2]]
- alpha = node.kwargs.get("alpha", 1)
- beta = node.kwargs.get("beta", 1)
-
- res = None
- if alpha != 0:
- res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32"))
- if alpha != 1:
- dtype = res.struct_info.dtype
- res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
- if beta != 0:
- dtype = x.struct_info.dtype
- if beta != 1:
- bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
- else:
- bias = x
- res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
- return res
-
- def _avg_pool2d_impl(
- self,
- x: relax.Expr,
- kernel_size: Union[int, Tuple[int, int]] = (1, 1),
- stride: Optional[Union[int, Tuple[int, int]]] = None,
- padding: Optional[int] = 0,
- ceil_mode: Optional[bool] = False,
- ) -> relax.Var:
- stride = kernel_size if stride is None or stride == [] else stride
- return self.block_builder.emit(
- relax.op.nn.avg_pool2d(
- x,
- pool_size=kernel_size,
- strides=stride,
- padding=padding,
- ceil_mode=ceil_mode,
- layout="NCHW",
- )
- )
-
- def _avg_pool2d(self, node: fx.Node) -> relax.Var:
- args, kwargs = node.normalized_arguments(node)
- x = self.env[args[0]]
- kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
- stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
- padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
- ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
- return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
-
def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -167,28 +116,6 @@ class TorchFXImporter(BaseFXGraphImporter):
ceil_mode = module.ceil_mode
return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
- def _baddbmm(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- a = self.env[node.args[1]]
- b = self.env[node.args[2]]
- alpha = node.kwargs.get("alpha", 1)
- beta = node.kwargs.get("beta", 1)
-
- res = None
- if alpha != 0:
- res = self.block_builder.emit(relax.op.matmul(a, b))
- if alpha != 1:
- dtype = res.struct_info.dtype
- res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
- if beta != 0:
- dtype = x.struct_info.dtype
- if beta != 1:
- bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
- else:
- bias = x
- res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
- return res
-
def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -212,63 +139,13 @@ class TorchFXImporter(BaseFXGraphImporter):
return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
- def _conv1d_transpose_impl(
- self,
- x: relax.Expr,
- weight: relax.Expr,
- bias: Optional[relax.Expr],
- strides: Optional[Tuple],
- padding: Optional[Tuple],
- dilation: Optional[Tuple],
- groups: Optional[Tuple],
- ) -> relax.Var:
- conv1d_transpose = self.block_builder.emit(
- relax.op.nn.conv1d_transpose(
- x,
- weight,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=groups,
- data_layout="NCW",
- kernel_layout="OIW",
- out_dtype="float32",
- )
- )
-
- if bias is None:
- return conv1d_transpose
-
- assert len(self.shape_of(bias)) == 1
- bias = relax.op.reshape(bias, (1, -1, 1))
- return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
-
- def _conv1d_transpose(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- stride = args[3] if len(args) > 3 else 1
- padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
- groups = args[6] if len(args) > 6 else 1
- return self._conv1d_transpose_impl(
- x,
- weight,
- bias=bias,
- strides=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- )
-
- def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var:
+ def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = self.params.get(module.bias, None)
- return self._conv1d_transpose_impl(
+ return self._conv_transpose1d_impl(
x,
weight,
bias=bias,
@@ -278,63 +155,13 @@ class TorchFXImporter(BaseFXGraphImporter):
groups=module.groups,
)
- def _conv2d_transpose_impl(
- self,
- x: relax.Expr,
- weight: relax.Expr,
- bias: Optional[relax.Expr],
- strides: Optional[Tuple],
- padding: Optional[Tuple],
- dilation: Optional[Tuple],
- groups: Optional[Tuple],
- ) -> relax.Var:
- conv2d_transpose = self.block_builder.emit(
- relax.op.nn.conv2d_transpose(
- x,
- weight,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=groups,
- data_layout="NCHW",
- kernel_layout="OIHW",
- out_dtype="float32",
- )
- )
-
- if bias is None:
- return conv2d_transpose
-
- assert len(self.shape_of(bias)) == 1
- bias = relax.op.reshape(bias, (1, -1, 1, 1))
- return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
-
- def _conv2d_transpose(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- stride = args[3] if len(args) > 3 else 1
- padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
- groups = args[6] if len(args) > 6 else 1
- return self._conv2d_transpose_impl(
- x,
- weight,
- bias=bias,
- strides=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- )
-
- def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var:
+ def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
weight = self.params[module.weight]
bias = self.params.get(module.bias, None)
- return self._conv2d_transpose_impl(
+ return self._conv_transpose2d_impl(
x,
weight,
bias=bias,
@@ -344,55 +171,6 @@ class TorchFXImporter(BaseFXGraphImporter):
groups=module.groups,
)
- def _conv1d_impl(
- self,
- x: relax.Expr,
- weight: relax.Expr,
- bias: Optional[relax.Expr],
- strides: Optional[Tuple],
- padding: Optional[Tuple],
- dilation: Optional[Tuple],
- groups: Optional[Tuple],
- ) -> relax.Var:
- conv1d = self.block_builder.emit(
- relax.op.nn.conv1d(
- x,
- weight,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=groups,
- data_layout="NCW",
- kernel_layout="OIW",
- out_dtype="float32",
- )
- )
-
- if bias is None:
- return conv1d
- assert len(self.shape_of(bias)) == 1
- bias = relax.op.reshape(bias, (1, -1, 1))
- return self.block_builder.emit(relax.op.add(conv1d, bias))
-
- def _conv1d(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- stride = args[3] if len(args) > 3 else 1
- padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
- groups = args[6] if len(args) > 6 else 1
- return self._conv1d_impl(
- x,
- weight,
- bias=bias,
- strides=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- )
-
def _conv1d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -425,55 +203,6 @@ class TorchFXImporter(BaseFXGraphImporter):
groups=module.groups,
)
- def _conv3d_impl(
- self,
- x: relax.Expr,
- weight: relax.Expr,
- bias: Optional[relax.Expr],
- strides: Optional[Tuple],
- padding: Optional[Tuple],
- dilation: Optional[Tuple],
- groups: Optional[Tuple],
- ):
- conv3d = self.block_builder.emit(
- relax.op.nn.conv3d(
- x,
- weight,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=groups,
- data_layout="NCDHW",
- kernel_layout="OIDHW",
- out_dtype="float32",
- )
- )
-
- if bias is None:
- return conv3d
- assert len(self.shape_of(bias)) == 1
- bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
- return self.block_builder.emit(relax.op.add(conv3d, bias))
-
- def _conv3d(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- stride = args[3] if len(args) > 3 else 1
- padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
- groups = args[6] if len(args) > 6 else 1
- return self._conv3d_impl(
- x,
- weight,
- bias=bias,
- strides=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- )
-
def _conv3d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -524,30 +253,6 @@ class TorchFXImporter(BaseFXGraphImporter):
)
)
- def _einsum(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
- return self.block_builder.emit(relax.op.einsum(operands, args[0]))
-
- def _embedding_impl(
- self,
- x,
- weight,
- ) -> relax.Var:
- x = self.block_builder.emit(relax.op.astype(x, "int32"))
-
- ndim = x.struct_info.ndim
- if ndim == 1:
- return self.block_builder.emit(relax.op.take(weight, x, axis=0))
- else:
- x_shape = x.struct_info.shape.values
- emb_size = weight.struct_info.shape.values[-1]
- x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
- embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
- return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
-
def _embedding_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -655,61 +360,6 @@ class TorchFXImporter(BaseFXGraphImporter):
)
)
- def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) ->
relax.Var:
- from torch.fx.immutable_collections import immutable_list
- import numpy as np # type: ignore
-
- if isinstance(normalized_shape, (immutable_list, tuple)):
- normalized_shape = tuple(normalized_shape)
- else:
- try:
- normalized_shape = self.env[normalized_shape]
- except TypeError:
- normalized_shape = tuple(normalized_shape)
-
- dim_num = len(normalized_shape)
- axes = list(range(-dim_num, 0))
-
- if gamma is None:
- shape_tuple = [int(s) for s in normalized_shape]
- gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype)
- if beta is None:
- shape_tuple = [int(s) for s in normalized_shape]
- beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
-
- return self.block_builder.emit(
- relax.op.nn.layer_norm(
- x,
- gamma,
- beta,
- axes=axes,
- epsilon=eps,
- )
- )
-
- def _layer_norm(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- normalized_shape = node.args[1]
- gamma = self.env[node.args[2]] if len(node.args) > 2 else None
- beta = self.env[node.args[3]] if len(node.args) > 3 else None
- eps = node.args[4] if len(node.args) > 4 else 1e-05
- return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
-
- def _layer_norm_module(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- x = self.env[node.args[0]]
- module = self.named_modules[node.target]
- normalized_shape = module.normalized_shape
- if module.elementwise_affine:
- gamma = self.params[module.weight]
- beta = self.params[module.bias]
- else:
- gamma = relax.const(torch.ones_like(module.normalized_shape),
x.struct_info.dtype)
- beta = relax.const(torch.zeros_like(module.normalized_shape),
x.struct_info.dtype)
- eps = module.eps
- return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
-
def _linear_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -728,39 +378,6 @@ class TorchFXImporter(BaseFXGraphImporter):
return self._max_pool2d_impl(x, kernel_size, stride, padding,
dilation, ceil_mode)
- def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
- transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1,
3])
- query = transpose_S_H(self.env[node.args[0]])
- key = transpose_S_H(self.env[node.args[1]])
- value = transpose_S_H(self.env[node.args[2]])
- attn_mask = node.args[3] if len(node.args) > 3 else
node.kwargs.get("attn_mask", None)
- dropout_p = node.args[4] if len(node.args) > 4 else
node.kwargs.get("dropout_p", 0.0)
- assert dropout_p == 0.0, "Dropout is not supported"
- is_causal = node.args[5] if len(node.args) > 5 else
node.kwargs.get("is_causal", False)
- causal_mask = "TopLeft" if is_causal else None
-
- if attn_mask is not None:
- attn_mask = self.env[attn_mask]
- msg = "Only a float mask is supported for the attn_mask input."
- assert "float" in attn_mask.struct_info.dtype, msg
-
- return self.block_builder.emit(
- transpose_S_H(
- relax.op.nn.attention(query, key, value, bias=attn_mask,
causal_mask=causal_mask)
- )
- )
-
- def _unbind(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
- assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
- selections = self.shape_of(x)[dim].value
- n_section = list(range(1, selections + 1))
- ret, split = [], self.block_builder.emit(relax.op.split(x, n_section,
dim))
- for i in range(selections):
- ret.append(self.block_builder.emit(relax.op.squeeze(split[i],
axis=dim)))
- return self.block_builder.emit(relax.Tuple(ret))
-
########## Manipulation ##########
def _cat(self, node: fx.Node) -> relax.Var:
@@ -1054,87 +671,6 @@ class TorchFXImporter(BaseFXGraphImporter):
return self.shape_of(self.env[node.args[0]])
return getattr(self.env[node.args[0]], node.args[1])
- def _getitem(self, node: fx.Node) -> relax.Var:
- import torch
-
- x = self.env[node.args[0]]
- if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
- return x[node.args[1]]
- elif isinstance(x, relax.Var):
- if isinstance(x.struct_info, relax.TupleStructInfo):
- return self.block_builder.emit(relax.TupleGetItem(x,
node.args[1]))
-
- assert isinstance(x.struct_info, relax.TensorStructInfo)
- take_indices = []
- take_axes = []
- stride_begin = []
- stride_end = []
- stride = []
- stride_axes = []
- expand_dim = []
- i = 0
- shape = self.shape_of(x)
- non_ellipsis_cnt = 0
- for index in node.args[1]:
- if isinstance(index, (int, slice, torch.fx.Node)):
- non_ellipsis_cnt += 1
- for index in node.args[1]:
- if isinstance(index, int):
- stride_begin.append(index)
- stride_end.append(index + 1)
- stride.append(1)
- stride_axes.append(i)
- i = i + 1
- elif isinstance(index, slice):
- stride_begin.append(0 if index.start is None else
index.start)
- stride_end.append(shape[i] if index.stop is None else
index.stop)
- stride.append(1 if index.step is None else index.step)
- stride_axes.append(i)
- i = i + 1
- elif index is None:
- expand_dim.append(len(stride_axes) + len(expand_dim))
- elif index is Ellipsis:
- for _ in range(len(shape) - non_ellipsis_cnt):
- stride_begin.append(0)
- stride_end.append(shape[i])
- stride.append(1)
- stride_axes.append(i)
- i += 1
- elif isinstance(index, torch.fx.Node):
- node_index = self.env[index]
- if not isinstance(node_index, relax.Expr):
- raise ValueError(
- "Unsupported index type for relax.op.take: " +
str(type(node_index))
- )
- take_indices.append(node_index)
- take_axes.append(i)
- i = i + 1
- else:
- raise ValueError("Unsupported index type: " +
str(type(index)))
- while i < len(shape):
- stride_begin.append(0)
- stride_end.append(shape[i])
- stride.append(1)
- stride_axes.append(i)
- i += 1
- taken = x
- if len(take_indices) > 1:
- raise ValueError("Multiple tensors as index not yet supported")
- for each_index, each_axis in zip(take_indices, take_axes):
- taken = self.block_builder.emit(relax.op.take(taken,
each_index, each_axis))
- sliced = self.block_builder.emit(
- relax.op.strided_slice(taken, stride_axes, stride_begin,
stride_end, stride)
- )
- sliced_shape = list(self.shape_of(sliced))
- for i in expand_dim:
- sliced_shape.insert(i, 1)
- return self.block_builder.emit(relax.op.reshape(sliced,
sliced_shape))
- elif isinstance(x, relax.Constant):
- dtype = x.struct_info.dtype
- return relax.const(x.data.numpy()[node.args[1]], dtype)
- else:
- assert False
-
def _sym_size_int(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -1182,8 +718,8 @@ class TorchFXImporter(BaseFXGraphImporter):
nn.Conv1d: self._conv1d_module,
nn.Conv2d: self._conv2d_module,
nn.Conv3d: self._conv3d_module,
- nn.ConvTranspose1d: self._conv1d_transpose_module,
- nn.ConvTranspose2d: self._conv2d_transpose_module,
+ nn.ConvTranspose1d: self._conv_transpose1d_module,
+ nn.ConvTranspose2d: self._conv_transpose2d_module,
nn.CrossEntropyLoss: self._cross_entropy_module,
nn.GroupNorm: self._group_norm_module,
nn.LayerNorm: self._layer_norm_module,
@@ -1248,8 +784,8 @@ class TorchFXImporter(BaseFXGraphImporter):
"bmm": self._binary_op(
partial(relax.op.linear_algebra.matmul, out_dtype="float32"),
operator.matmul
),
- "conv_transpose1d": self._conv1d_transpose,
- "conv_transpose2d": self._conv2d_transpose,
+ "conv_transpose1d": self._conv_transpose1d,
+ "conv_transpose2d": self._conv_transpose2d,
"conv1d": self._conv1d,
"conv2d": self._conv2d,
"conv3d": self._conv3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 25e6dbfae3..7c887d9b96 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1156,6 +1156,59 @@ def test_binary():
verify_model(Sub2(), example_args2, {}, expected_sub2)
+def test_batchnorm2d():
+ class BatchNorm2d(Module):
+ def __init__(self):
+ super().__init__()
+ self.bn = torch.nn.BatchNorm2d(3)
+
+ def forward(self, input):
+ return self.bn(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((3,), dtype="float32"),
+ w2: R.Tensor((3,), dtype="float32"),
+ w3: R.Tensor((3,), dtype="float32"),
+ w4: R.Tensor((3,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ R.Tensor((3,), dtype="float32"),
+ ) = R.nn.batch_norm(
+ input_1,
+ w1,
+ w2,
+ w3,
+ w4,
+ axis=1,
+ epsilon=1e-05,
+ center=True,
+ scale=True,
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = BatchNorm2d().eval()
+ binding = {
+ "w1": model.bn.weight.detach().numpy(),
+ "w2": model.bn.bias.detach().numpy(),
+ "w3": model.bn.running_mean.detach().numpy(),
+ "w4": model.bn.running_var.detach().numpy(),
+ }
+ verify_model(model, example_args, binding, expected1)
+
+
def test_adaptive_avgpool2d():
class AdaptiveAvgPool2d0(Module):
def __init__(self):
@@ -1165,28 +1218,594 @@ def test_adaptive_avgpool2d():
def forward(self, input):
return self.pool(input)
- class AdaptiveAvgPool2d1(Module):
+ class AdaptiveAvgPool2d1(Module):
+ def forward(self, input):
+ return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.adaptive_avg_pool2d(
+ input_1, output_size=[10, 10], layout="NCHW",
out_layout="NCHW"
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+ verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+
+
+def test_addmm():
+ class Addmm1(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x1, x2, x3):
+ return torch.addmm(x1, x2, x3)
+
+ class Addmm2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x1, x2, x3):
+ return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x1: R.Tensor((10, 10), dtype="float32"),
+ x2: R.Tensor((10, 10), dtype="float32"),
+ x3: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3,
out_dtype="float32")
+ lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ x1: R.Tensor((10, 10), dtype="float32"),
+ x2: R.Tensor((10, 10), dtype="float32"),
+ x3: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3,
out_dtype="float32")
+ lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv,
R.const(0.5, "float32"))
+ lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1,
R.const(0.8, "float32"))
+ lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+ gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(10, 10, dtype=torch.float32),
+ torch.randn(10, 10, dtype=torch.float32),
+ torch.randn(10, 10, dtype=torch.float32),
+ )
+
+ verify_model(Addmm1(), example_args, {}, expected1)
+ verify_model(Addmm2(), example_args, {}, expected2)
+
+
+def test_avg_pool2d():
+ class AvgPool2d1(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.avg_pool2d(
+ input_1,
+ pool_size=[1, 1],
+ strides=[1, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class AvgPool2d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2,
padding=2, ceil_mode=True)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class AvgPool2d3(Module):
+ def forward(self, input):
+ return torch.nn.functional.avg_pool2d(
+ input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True
+ )
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.avg_pool2d(
+ input_1,
+ pool_size=[4, 4],
+ strides=[2, 2],
+ dilation=[1, 1],
+ padding=[2, 2, 2, 2],
+ ceil_mode=True,
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ class AvgPool2d4(Module):
+ def forward(self, input):
+ return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1],
divisor_override=2)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv = R.nn.avg_pool2d(
+ input_1,
+ pool_size=[2, 1],
+ strides=[2, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ ceil_mode=False,
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(AvgPool2d1(), example_args, {}, expected1)
+ verify_model(AvgPool2d2(), example_args, {}, expected2)
+ verify_model(AvgPool2d3(), example_args, {}, expected2)
+ verify_model(AvgPool2d4(), example_args, {}, expected3)
+
+
+def test_baddbmm():
+ class BAddBMM1(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, c, x, y):
+ return torch.baddbmm(c, x, y)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+ inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+ inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1,
inp_2)
+ lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv,
inp_0)
+ gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ class BAddBMM2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, c, x, y):
+ return torch.baddbmm(c, x, y, alpha=2, beta=0)
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+ inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+ inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1,
inp_2)
+ lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+ lv, R.const(2, "float32")
+ )
+ gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ class BAddBMM3(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, c, x, y):
+ return torch.baddbmm(c, x, y, alpha=2, beta=3)
+
+ @tvm.script.ir_module
+ class Expected3:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+ inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+ inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1,
inp_2)
+ lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+ lv, R.const(2, "float32")
+ )
+ lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+ inp_0, R.const(3, "float32")
+ )
+ lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(4, 128, 512, dtype=torch.float32),
+ torch.randn(4, 128, 256, dtype=torch.float32),
+ torch.randn(4, 256, 512, dtype=torch.float32),
+ )
+ verify_model(
+ BAddBMM1(),
+ example_args,
+ {},
+ Expected1,
+ )
+
+ verify_model(
+ BAddBMM2(),
+ example_args,
+ {},
+ Expected2,
+ )
+
+ verify_model(
+ BAddBMM3(),
+ example_args,
+ {},
+ Expected3,
+ )
+
+
+def test_bmm():
+ class BMM(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.bmm(x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((4, 128, 256), dtype="float32"),
+ input_2: R.Tensor((4, 256, 512), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+ input_1, input_2, out_dtype="float32"
+ )
+ gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(4, 128, 256, dtype=torch.float32),
+ torch.randn(4, 256, 512, dtype=torch.float32),
+ )
+ verify_model(
+ BMM(),
+ example_args,
+ {},
+ Expected,
+ )
+
+
+def test_conv_transpose1d():
+ class ConvTranspose1d1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ class ConvTranspose1d1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[6, 6, 3])
+ self.bias = torch.randn(size=[6])
+
+ def forward(self, input):
+ return torch.nn.functional.conv_transpose1d(input, self.weight,
self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 6, 4), dtype="float32"),
+ w1: R.Tensor((6, 6, 3), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 6), dtype="float32") =
R.nn.conv1d_transpose(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+ lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ class ConvTranspose1d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 6, 4), dtype="float32"),
+ w1: R.Tensor((6, 6, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 6), dtype="float32") =
R.nn.conv1d_transpose(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
+ )
+ gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 6, 4, dtype=torch.float32),)
+
+ model = ConvTranspose1d1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = ConvTranspose1d1Func()
+ binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = ConvTranspose1d2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
+
+
+def test_conv_transpose2d():
+ class ConvTranspose2d1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ class ConvTranspose2d1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[3, 3, 7, 7])
+ self.bias = torch.randn(size=[3])
+
+ def forward(self, input):
+ return torch.nn.functional.conv_transpose2d(input, self.weight,
self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
+ w2: R.Tensor((3,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 3, 16, 16), dtype="float32") =
R.nn.conv2d_transpose(
+ input_1,
+ w1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+ lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1,
lv2)
+ gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ class ConvTranspose2d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 3, 16, 16), dtype="float32") =
R.nn.conv2d_transpose(
+ input_1,
+ w1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = ConvTranspose2d1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = ConvTranspose2d1Func()
+ binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = ConvTranspose2d2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
+
+
+def test_conv1d():
+ class Conv1D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ class Conv1D1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[6, 3, 7])
+ self.bias = torch.randn(size=[6])
+
+ def forward(self, input):
+ return torch.nn.functional.conv1d(input, self.weight, self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ w1: R.Tensor((6, 3, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ input_1: R.Tensor((1, 3, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1,
6, 1])
+ lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ class Conv1D2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+
def forward(self, input):
- return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+ return self.conv(input)
@tvm.script.ir_module
- class expected1:
+ class expected2:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
- ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ w1: R.Tensor((6, 3, 7), dtype="float32"),
+ input_1: R.Tensor((1, 3, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
# block 0
with R.dataflow():
- lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.adaptive_avg_pool2d(
- input_1, output_size=[10, 10], layout="NCHW",
out_layout="NCHW"
+ lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0],
+ dilation=[1],
+ data_layout="NCW",
+ kernel_layout="OIW",
+ out_layout="NCW",
+ out_dtype="float32",
)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,)
R.output(gv)
return gv
- example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
- verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+ example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
+
+ model = Conv1D1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv1D1Func()
+ binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv1D2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
def test_conv2d():
@@ -1281,6 +1900,267 @@ def test_conv2d():
verify_model(model, example_args, binding, expected2)
+def test_conv3d():
+ class Conv3D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv3d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ class Conv3D1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[6, 3, 7, 7, 7])
+ self.bias = torch.randn(size=[6])
+
+ def forward(self, input):
+ return torch.nn.functional.conv3d(input, self.weight, self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0, 0],
+ dilation=[1],
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_layout="NCDHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+ lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1,
lv2)
+ gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) =
(lv3,)
+ R.output(gv)
+ return gv
+
+ class Conv3D2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv3d(3, 6, 7, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+ input_1,
+ w1,
+ strides=[1],
+ padding=[0, 0, 0],
+ dilation=[1],
+ data_layout="NCDHW",
+ kernel_layout="OIDHW",
+ out_layout="NCDHW",
+ out_dtype="float32",
+ )
+ gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) =
(lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
+
+ model = Conv3D1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv3D1Func()
+ binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv3D2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
+
+
+def test_einsum():
+ class Einsum1(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return torch.einsum("ii", x)
+
+ class Einsum2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.einsum("i,j->ij", x, y)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((4, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,),
subscripts="ii")
+ gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+ (inp_0, inp_1), subscripts="i,j->ij"
+ )
+ gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(4, 4, dtype=torch.float32),)
+ verify_model(Einsum1(), example_args, {}, Expected1)
+
+ example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4,
dtype=torch.float32))
+ verify_model(Einsum2(), example_args, {}, Expected2)
+
+
+def test_embedding():
+ class Embedding(Module):
+ def __init__(self):
+ super().__init__()
+ self.embedding = torch.nn.Embedding(10, 3)
+
+ def forward(self, input):
+ return self.embedding(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3),
dtype="float32")
+ ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="int32") = R.astype(input_1,
dtype="int32")
+ lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
+ gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,),
dtype=torch.int64),)
+
+ model = Embedding()
+ binding = {"w1": model.embedding.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+
+def test_groupnorm():
+ import torch
+ from torch.nn import Module
+
+ torch.set_grad_enabled(False)
+ torch.random.manual_seed(0)
+
+ class GroupNorm(Module):
+ def __init__(self):
+ super().__init__()
+ self.gn = torch.nn.GroupNorm(3, 3)
+
+ def forward(self, input):
+ return self.gn(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((3,), dtype="float32"),
+ w2: R.Tensor((3,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.group_norm(
+ input_1,
+ w1,
+ w2,
+ num_groups=3,
+ channel_axis=1,
+ axes=[2, 3],
+ epsilon=1.0000000000000001e-05,
+ center=True,
+ scale=True,
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = GroupNorm()
+ binding = {
+ "w1": model.gn.weight.detach().numpy(),
+ "w2": model.gn.bias.detach().numpy(),
+ }
+ verify_model(model, example_args, binding, expected1)
+
+
+def test_layernorm():
+ class LayerNorm(Module):
+ def __init__(self):
+ super().__init__()
+ self.ln = torch.nn.LayerNorm((10, 10))
+
+ def forward(self, input):
+ return self.ln(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((10, 10), dtype="float32"),
+ w2: R.Tensor((10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.layer_norm(
+ input_1,
+ w1,
+ w2,
+ axes=[-2, -1],
+ epsilon=1e-05,
+ center=True,
+ scale=True,
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = LayerNorm()
+ binding = {
+ "w1": model.ln.weight.detach().numpy(),
+ "w2": model.ln.bias.detach().numpy(),
+ }
+ verify_model(LayerNorm(), example_args, binding, expected1)
+
+
def test_linear():
class Dense1(Module):
def __init__(self):
@@ -1460,6 +2340,254 @@ def test_maxpool2d():
verify_model(MaxPool2d3(), example_args, {}, expected3)
+def test_scaled_dot_product_attention():
+ # Checks the import of torch.nn.functional.scaled_dot_product_attention,
+ # once without a mask (Attention1 / Expected1) and once with a mask passed
+ # as the fourth positional argument (Attention2 / Expected2).
+ class Attention1(Module):
+ def forward(self, q, k, v):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+ # Expected1: each of q/k/v is permuted with axes [0, 2, 1, 3], fed to
+ # R.nn.attention with scale=None, and the result is permuted back with the
+ # same axes.
+ # NOTE(review): presumably R.nn.attention expects the sequence axis before
+ # the head axis - confirm against the op definition.
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_0, axes=[0, 2, 1, 3]
+ )
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_1, axes=[0, 2, 1, 3]
+ )
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_2, axes=[0, 2, 1, 3]
+ )
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.nn.attention(
+ lv, lv1, lv2, scale=None
+ )
+ lv4: R.Tensor((32, 8, 128, 64), dtype="float32") =
R.permute_dims(
+ lv3, axes=[0, 2, 1, 3]
+ )
+ gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) =
(lv4,)
+ R.output(gv)
+ return gv
+
+ class Attention2(Module):
+ def forward(self, q, k, v, mask):
+ return torch.nn.functional.scaled_dot_product_attention(q, k, v,
mask)
+
+ # Expected2: same permute / attention / permute pattern as Expected1, but
+ # the (32, 8, 128, 128) mask input inp_3 is forwarded unpermuted as the
+ # fourth positional argument of R.nn.attention.
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_0, axes=[0, 2, 1, 3]
+ )
+ lv1: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_1, axes=[0, 2, 1, 3]
+ )
+ lv2: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.permute_dims(
+ inp_2, axes=[0, 2, 1, 3]
+ )
+ lv3: R.Tensor((32, 128, 8, 64), dtype="float32") =
R.nn.attention(
+ lv, lv1, lv2, inp_3, scale=None
+ )
+ lv4: R.Tensor((32, 8, 128, 64), dtype="float32") =
R.permute_dims(
+ lv3, axes=[0, 2, 1, 3]
+ )
+ gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) =
(lv4,)
+ R.output(gv)
+ return gv
+
+ # Mask-free case: three (32, 8, 128, 64) float32 inputs, no bound params.
+ verify_model(
+ Attention1(),
+ (
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ ),
+ {},
+ Expected1,
+ )
+
+ # Masked case: the extra fourth input is the (32, 8, 128, 128) mask.
+ verify_model(
+ Attention2(),
+ (
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ torch.randn(32, 8, 128, 64, dtype=torch.float32),
+ torch.randn(32, 8, 128, 128, dtype=torch.float32),
+ ),
+ {},
+ Expected2,
+ )
+
+
+def test_unbind():
+ # Checks the import of torch.unbind on a (3, 3, 10, 10) input, once with
+ # the default dim (Unbind1 / expected1) and once with dim=1
+ # (Unbind2 / expected2).
+ class Unbind1(Module):
+ def forward(self, data):
+ return torch.unbind(data)
+
+ # expected1: R.split along axis 0 at indices [1, 2, 3] yields four pieces,
+ # the last having extent 0 on the split axis and never being read. Pieces
+ # 0..2 are squeezed on axis 0, packed into a tuple (lv7), and then
+ # re-extracted element by element to build the output tuple.
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((0, 3, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[0])
+ lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+ lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[0])
+ lv7: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv2, lv4, lv6)
+ lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+ lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+ lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ gv: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv8, lv9, lv10)
+ R.output(gv)
+ return gv
+
+ class Unbind2(Module):
+ def forward(self, data):
+ return torch.unbind(data, dim=1)
+
+ # expected2: identical structure to expected1, but the split and the
+ # squeezes operate on axis 1 instead of axis 0.
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ):
+ # block 0
+ with R.dataflow():
+ lv: R.Tuple(
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 1, 10, 10), dtype="float32"),
+ R.Tensor((3, 0, 10, 10), dtype="float32"),
+ ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+ lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+ lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
+ lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[1])
+ lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+ lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[1])
+ lv7: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv2, lv4, lv6)
+ lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+ lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+ lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ gv: R.Tuple(
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ R.Tensor((3, 10, 10), dtype="float32"),
+ ) = (lv8, lv9, lv10)
+ R.output(gv)
+ return gv
+
+ # Both variants share the same input; neither has parameters to bind.
+ example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+ verify_model(Unbind1(), example_args, {}, expected1)
+ verify_model(Unbind2(), example_args, {}, expected2)
+
+
+def test_interpolate():
+ # Checks the import of torch.nn.functional.interpolate to a fixed
+ # (224, 224) output size for mode="bilinear" and mode="nearest"; both are
+ # expected to become a single R.image.resize2d call in NCHW layout.
+ class InterpolateBilinear(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, (224, 224),
mode="bilinear")
+
+ # Bilinear mode is expected as resize2d with method="linear".
+ @tvm.script.ir_module
+ class expected_bilinear:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 112, 112), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 224, 224), dtype="float32") =
R.image.resize2d(
+ input,
+ R.shape([224, 224]),
+ roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0),
T.float32(0.0)],
+ layout="NCHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.5,
+ cubic_exclude=0,
+ extrapolation_value=0.0,
+ out_dtype="void",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ class InterpolateNearest(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, (224, 224),
mode="nearest")
+
+ # Nearest mode differs from the bilinear expectation only in
+ # method="nearest_neighbor"; every other resize2d attribute is the same.
+ @tvm.script.ir_module
+ class expected_nearest:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 112, 112), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 224, 224), dtype="float32") =
R.image.resize2d(
+ input,
+ R.shape([224, 224]),
+ roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0),
T.float32(0.0)],
+ layout="NCHW",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="round",
+ cubic_alpha=-0.5,
+ cubic_exclude=0,
+ extrapolation_value=0.0,
+ out_dtype="void",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) =
(lv,)
+ R.output(gv)
+ return gv
+
+ # One 112x112 input is upsampled to 224x224 by both variants.
+ example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
+ verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
+ verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+
+
def test_mean():
class Mean(Module):
def forward(self, input):