This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d9ee6377cd [Relax][PyTorch] Support neural network ops for ExportedProgram importer (#17426)
d9ee6377cd is described below

commit d9ee6377cdd8395b27385d2fc2745b741fad6183
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Sun Sep 29 06:59:33 2024 +0900

    [Relax][PyTorch] Support neural network ops for ExportedProgram importer (#17426)
    
    * support batchnorm2d and getitem
    
    * support addmm
    
    * support avg_pool2d
    
    * support baddbmm
    
    * support bmm
    
    * support conv_transpose1d
    
    * support conv_transpose2d
    
    * support conv1d
    
    * support conv3d
    
    * support einsum
    
    * support embedding
    
    * support group_norm
    
    * support layer_norm
    
    * support scaled_dot_product_attention
    
    * support unbind
    
    * support interpolate
    
    * fix lint error
---
 .../frontend/torch/base_fx_graph_translator.py     |  464 ++++++++
 .../frontend/torch/exported_program_translator.py  |  111 ++
 python/tvm/relax/frontend/torch/fx_translator.py   |  482 +-------
 .../relax/test_frontend_from_exported_program.py   | 1150 +++++++++++++++++++-
 4 files changed, 1723 insertions(+), 484 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a41b9b6d4f..52784dc8c3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -227,6 +227,228 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
         )
 
+    def _addmm(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        y = self.env[node.args[1]]
+        z = self.env[node.args[2]]
+        alpha = node.kwargs.get("alpha", 1)
+        beta = node.kwargs.get("beta", 1)
+
+        res = None
+        if alpha != 0:
+            res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32"))
+            if alpha != 1:
+                dtype = res.struct_info.dtype
+                res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
+        if beta != 0:
+            dtype = x.struct_info.dtype
+            if beta != 1:
+                bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
+            else:
+                bias = x
+            res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
+        return res
+
+    def _avg_pool2d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+        stride: Optional[Union[int, Tuple[int, int]]] = None,
+        padding: Optional[int] = 0,
+        ceil_mode: Optional[bool] = False,
+    ) -> relax.Var:
+        stride = kernel_size if stride is None or stride == [] else stride
+        return self.block_builder.emit(
+            relax.op.nn.avg_pool2d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                ceil_mode=ceil_mode,
+                layout="NCHW",
+            )
+        )
+
+    def _avg_pool2d(self, node: fx.Node) -> relax.Var:
+        args, kwargs = node.normalized_arguments(node)
+        x = self.env[args[0]]
+        kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
+        stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
+        padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
+        ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
+        return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
+
+    def _baddbmm(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        batch1 = self.env[node.args[1]]
+        batch2 = self.env[node.args[2]]
+        alpha = node.kwargs.get("alpha", 1)
+        beta = node.kwargs.get("beta", 1)
+
+        res = None
+        if alpha != 0:
+            res = self.block_builder.emit(relax.op.matmul(batch1, batch2))
+            if alpha != 1:
+                dtype = res.struct_info.dtype
+                res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
+        if beta != 0:
+            dtype = x.struct_info.dtype
+            if beta != 1:
+                bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
+            else:
+                bias = x
+            res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
+        return res
+
+    def _conv_transpose1d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv1d_transpose = self.block_builder.emit(
+            relax.op.nn.conv1d_transpose(
+                x,
+                weight,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCW",
+                kernel_layout="OIW",
+                out_dtype="float32",
+            )
+        )
+
+        if bias is None:
+            return conv1d_transpose
+
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1))
+        return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
+
+    def _conv_transpose1d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x, weight = args[0], args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        assert len(args) < 6 or args[5] in (0, [0], (0,)), "output_padding is not supported"
+        groups = args[6] if len(args) > 6 else 1
+        dilation = args[7] if len(args) > 7 else 1  # after output_padding (5) and groups (6)
+        return self._conv_transpose1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _conv_transpose2d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv2d_transpose = self.block_builder.emit(
+            relax.op.nn.conv2d_transpose(
+                x,
+                weight,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+                out_dtype="float32",
+            )
+        )
+
+        if bias is None:
+            return conv2d_transpose
+
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1))
+        return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
+
+    def _conv_transpose2d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x, weight = args[0], args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        assert len(args) < 6 or args[5] in (0, [0, 0], (0, 0)), "output_padding is not supported"
+        groups = args[6] if len(args) > 6 else 1
+        dilation = args[7] if len(args) > 7 else 1  # after output_padding (5) and groups (6)
+        return self._conv_transpose2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _conv1d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv1d = self.block_builder.emit(
+            relax.op.nn.conv1d(
+                x,
+                weight,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCW",
+                kernel_layout="OIW",
+                out_dtype="float32",
+            )
+        )
+
+        if bias is None:
+            return conv1d
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1))
+        return self.block_builder.emit(relax.op.add(conv1d, bias))
+
+    def _conv1d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv1d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
     def _conv2d_impl(
         self,
         x: relax.Expr,
@@ -276,6 +498,134 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             groups=groups,
         )
 
+    def _conv3d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ) -> relax.Var:
+        conv3d = self.block_builder.emit(
+            relax.op.nn.conv3d(
+                x,
+                weight,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCDHW",
+                kernel_layout="OIDHW",
+                out_dtype="float32",
+            )
+        )
+
+        if bias is None:
+            return conv3d
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
+        return self.block_builder.emit(relax.op.add(conv3d, bias))
+
+    def _conv3d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv3d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _einsum(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.einsum(operands, args[0]))
+
+    def _embedding_impl(
+        self,
+        x,
+        weight,
+    ) -> relax.Var:
+        x = self.block_builder.emit(relax.op.astype(x, "int32"))
+
+        ndim = x.struct_info.ndim
+        if ndim == 1:
+            return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+        else:
+            x_shape = x.struct_info.shape.values
+            emb_size = weight.struct_info.shape.values[-1]
+            x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
+            embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
+            return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
+
+    def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
+        from torch.fx.immutable_collections import immutable_list
+        import numpy as np  # type: ignore
+
+        if isinstance(normalized_shape, (immutable_list, tuple)):
+            normalized_shape = tuple(normalized_shape)
+        else:
+            try:
+                normalized_shape = self.env[normalized_shape]
+            except TypeError:
+                normalized_shape = tuple(normalized_shape)
+
+        dim_num = len(normalized_shape)
+        axes = list(range(-dim_num, 0))
+
+        if gamma is None:
+            shape_tuple = [int(s) for s in normalized_shape]
+            gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype)
+        if beta is None:
+            shape_tuple = [int(s) for s in normalized_shape]
+            beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
+
+        return self.block_builder.emit(
+            relax.op.nn.layer_norm(
+                x,
+                gamma,
+                beta,
+                axes=axes,
+                epsilon=eps,
+            )
+        )
+
+    def _layer_norm(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        normalized_shape = node.args[1]
+        gamma = self.env[node.args[2]] if len(node.args) > 2 and node.args[2] is not None else None
+        beta = self.env[node.args[3]] if len(node.args) > 3 and node.args[3] is not None else None
+        eps = node.args[4] if len(node.args) > 4 else 1e-05
+        return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
+    def _layer_norm_module(self, node: fx.Node) -> relax.Var:
+        # Missing weight/bias are passed as None; _layer_norm_impl fills in ones/zeros.
+
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        normalized_shape = module.normalized_shape
+        if module.elementwise_affine:
+            gamma = self.params[module.weight]
+            beta = self.params[module.bias] if module.bias is not None else None
+        else:
+            gamma = None
+            beta = None
+        eps = module.eps
+        return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
+
     def _linear(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
@@ -316,6 +666,39 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
+        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+        query = transpose_S_H(self.env[node.args[0]])
+        key = transpose_S_H(self.env[node.args[1]])
+        value = transpose_S_H(self.env[node.args[2]])
+        attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
+        dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
+        assert dropout_p == 0.0, "Dropout is not supported"
+        is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
+        causal_mask = "TopLeft" if is_causal else None
+
+        if attn_mask is not None:
+            attn_mask = self.env[attn_mask]
+            msg = "Only a float mask is supported for the attn_mask input."
+            assert "float" in attn_mask.struct_info.dtype, msg
+
+        return self.block_builder.emit(
+            transpose_S_H(
+                relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
+            )
+        )
+
+    def _unbind(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
+        selections = self.shape_of(x)[dim].value
+        n_section = list(range(1, selections + 1))
+        ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
+        for i in range(selections):
+            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
+        return self.block_builder.emit(relax.Tuple(ret))
+
     ########## Statistical ##########
 
     def _mean(self, node: fx.Node) -> relax.Var:
@@ -357,6 +740,87 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
 
     ########## Others ##########
 
+    def _getitem(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        x = self.env[node.args[0]]
+        if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
+            return x[node.args[1]]
+        elif isinstance(x, relax.Var):
+            if isinstance(x.struct_info, relax.TupleStructInfo):
+                return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
+
+            assert isinstance(x.struct_info, relax.TensorStructInfo)
+            take_indices = []
+            take_axes = []
+            stride_begin = []
+            stride_end = []
+            stride = []
+            stride_axes = []
+            expand_dim = []
+            i = 0
+            shape = self.shape_of(x)
+            non_ellipsis_cnt = 0
+            for index in node.args[1]:
+                if isinstance(index, (int, slice, torch.fx.Node)):
+                    non_ellipsis_cnt += 1
+            for index in node.args[1]:
+                if isinstance(index, int):
+                    stride_begin.append(index)
+                    stride_end.append(index + 1)
+                    stride.append(1)
+                    stride_axes.append(i)
+                    i = i + 1
+                elif isinstance(index, slice):
+                    stride_begin.append(0 if index.start is None else index.start)
+                    stride_end.append(shape[i] if index.stop is None else index.stop)
+                    stride.append(1 if index.step is None else index.step)
+                    stride_axes.append(i)
+                    i = i + 1
+                elif index is None:
+                    expand_dim.append(len(stride_axes) + len(expand_dim))
+                elif index is Ellipsis:
+                    for _ in range(len(shape) - non_ellipsis_cnt):
+                        stride_begin.append(0)
+                        stride_end.append(shape[i])
+                        stride.append(1)
+                        stride_axes.append(i)
+                        i += 1
+                elif isinstance(index, torch.fx.Node):
+                    node_index = self.env[index]
+                    if not isinstance(node_index, relax.Expr):
+                        raise ValueError(
+                            "Unsupported index type for relax.op.take: " + str(type(node_index))
+                        )
+                    take_indices.append(node_index)
+                    take_axes.append(i)
+                    i = i + 1
+                else:
+                    raise ValueError("Unsupported index type: " + str(type(index)))
+            while i < len(shape):
+                stride_begin.append(0)
+                stride_end.append(shape[i])
+                stride.append(1)
+                stride_axes.append(i)
+                i += 1
+            taken = x
+            if len(take_indices) > 1:
+                raise ValueError("Multiple tensors as index not yet supported")
+            for each_index, each_axis in zip(take_indices, take_axes):
+                taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis))
+            sliced = self.block_builder.emit(
+                relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride)
+            )
+            sliced_shape = list(self.shape_of(sliced))
+            for i in expand_dim:
+                sliced_shape.insert(i, 1)
+            return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
+        elif isinstance(x, relax.Constant):
+            dtype = x.struct_info.dtype
+            return relax.const(x.data.numpy()[node.args[1]], dtype)
+        else:
+            assert False
+
     @abc.abstractmethod
     def create_convert_map(
         self,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 11594690cd..64583d7509 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -74,6 +74,94 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         max_val = node.args[2] if len(args) > 2 else node.kwargs("max_val", 1.0)
         return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
 
+    ########## Neural Network ##########
+
+    def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+        import numpy as np
+
+        x = self.env[node.args[0]]
+        channel = int(self.shape_of(x)[1])
+        dtype = x.struct_info.dtype
+        weight = self.env.get(node.args[1], relax.const(np.ones(channel), dtype=dtype))
+        bias = self.env.get(node.args[2], relax.const(np.zeros(channel), dtype=dtype))
+        running_mean = self.env.get(node.args[3], relax.const(np.zeros(channel), dtype=dtype))
+        running_var = self.env.get(node.args[4], relax.const(np.ones(channel), dtype=dtype))
+        momentum = node.args[5] if len(node.args) > 5 else node.kwargs.get("momentum", 0.1)
+        eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps", 1e-05)
+
+        return self.block_builder.emit(
+            relax.op.nn.batch_norm(
+                x,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                axis=1,
+                epsilon=eps,
+                momentum=momentum,
+            )
+        )
+
+    def _group_norm(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        num_groups = node.args[1]
+        gamma = self.env[node.args[2]] if len(node.args) > 2 else None
+        beta = self.env[node.args[3]] if len(node.args) > 3 else None
+        eps = node.args[4] if len(node.args) > 4 else 1e-05
+
+        dim = len(self.shape_of(x))
+        return self.block_builder.emit(
+            relax.op.nn.group_norm(
+                x,
+                gamma,
+                beta,
+                num_groups=num_groups,
+                channel_axis=1,
+                axes=list(range(2, dim)),
+                epsilon=eps,
+            )
+        )
+
+    def _upsample_impl(
+        self, x: relax.Expr, size, align_corners: bool, scale_factor, method: str
+    ) -> relax.Var:
+        coord_trans = "align_corners" if align_corners else "half_pixel"
+
+        if size is None:
+            shape = self.shape_of(x)
+            assert isinstance(shape, relax.ShapeExpr)
+            if isinstance(scale_factor, (tuple, list)):
+                assert len(scale_factor) == len(shape) - 2
+                size = tuple(
+                    int(shape[i].value * scale_factor[i - 2]) for i in range(2, len(shape))
+                )
+            else:
+                size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
+
+        return self.block_builder.emit(
+            relax.op.image.resize2d(
+                x, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans
+            )
+        )
+
+    def _upsample_bilinear2d(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
+        align_corners = (
+            node.args[2] if len(node.args) > 2 else node.kwargs.get("align_corners", True)
+        )
+        scale_factor = node.args[3] if len(node.args) > 3 else node.kwargs.get("scale_factor", None)
+        return self._upsample_impl(x, size, align_corners, scale_factor, "linear")
+
+    def _upsample_nearest2d(self, node: fx.Node) -> relax.Var:
+        # aten::upsample_nearest2d.vec(input, output_size, scale_factors): args[2] is the scale.
+        x = self.env[node.args[0]]
+        size = node.args[1] if len(node.args) > 1 else node.kwargs.get("size", None)
+        scale_factor = node.args[2] if len(node.args) > 2 else node.kwargs.get("scale_factor", None)
+        # align_corners is not in this schema; False selects the half_pixel mapping.
+        align_corners = False
+        return self._upsample_impl(x, size, align_corners, scale_factor, "nearest_neighbor")
+
     def create_convert_map(
         self,
     ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
@@ -129,10 +217,31 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "pow.Tensor_Tensor": self._binary_op(relax.op.power, operator.pow),
             "sub.Tensor": self._binary_op(relax.op.subtract, operator.sub),
             # neural network
+            "_native_batch_norm_legit_no_training.default": self._batch_norm_legit_no_training,
             "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+            "addmm.default": self._addmm,
+            "avg_pool2d.default": self._avg_pool2d,
+            "baddbmm.default": self._baddbmm,
+            "bmm.default": self._binary_op(
+                partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
+            ),
+            "conv_transpose1d.default": self._conv_transpose1d,
+            "conv_transpose2d.input": self._conv_transpose2d,
+            "conv1d.default": self._conv1d,
             "conv2d.default": self._conv2d,
+            "conv3d.default": self._conv3d,
+            "einsum.default": self._einsum,
+            "embedding.default": lambda node: self._embedding_impl(
+                self.env[node.args[1]], self.env[node.args[0]]
+            ),
+            "group_norm.default": self._group_norm,
+            "layer_norm.default": self._layer_norm,
             "linear.default": self._linear,
             "max_pool2d.default": self._max_pool2d,
+            "scaled_dot_product_attention.default": self._scaled_dot_product_attention,
+            "unbind.int": self._unbind,
+            "upsample_bilinear2d.vec": self._upsample_bilinear2d,
+            "upsample_nearest2d.vec": self._upsample_nearest2d,
             # statistical
             "mean.dim": self._mean,
             "sum.dim_IntList": self._sum,
@@ -141,6 +250,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "argmin.default": self._argmax_argmin(relax.op.argmin),
             # tensor manipulation
             "view.default": self._reshape,
+            # other
+            "getitem": self._getitem,
         }
 
     def from_exported_program(
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dc6ebc2eb3..c60c7c3953 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Tuple, Union
 from functools import partial, reduce
 
 import tvm
@@ -107,57 +107,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
         )
 
-    def _addmm(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        y = self.env[node.args[1]]
-        z = self.env[node.args[2]]
-        alpha = node.kwargs.get("alpha", 1)
-        beta = node.kwargs.get("beta", 1)
-
-        res = None
-        if alpha != 0:
-            res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32"))
-            if alpha != 1:
-                dtype = res.struct_info.dtype
-                res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
-        if beta != 0:
-            dtype = x.struct_info.dtype
-            if beta != 1:
-                bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
-            else:
-                bias = x
-            res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res))
-        return res
-
-    def _avg_pool2d_impl(
-        self,
-        x: relax.Expr,
-        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[int] = 0,
-        ceil_mode: Optional[bool] = False,
-    ) -> relax.Var:
-        stride = kernel_size if stride is None or stride == [] else stride
-        return self.block_builder.emit(
-            relax.op.nn.avg_pool2d(
-                x,
-                pool_size=kernel_size,
-                strides=stride,
-                padding=padding,
-                ceil_mode=ceil_mode,
-                layout="NCHW",
-            )
-        )
-
-    def _avg_pool2d(self, node: fx.Node) -> relax.Var:
-        args, kwargs = node.normalized_arguments(node)
-        x = self.env[args[0]]
-        kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"]
-        stride = args[2] if len(args) > 2 else kwargs.get("stride", None)
-        padding = args[3] if len(args) > 3 else kwargs.get("padding", 0)
-        ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False)
-        return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
-
     def _avg_pool2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -167,28 +116,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         ceil_mode = module.ceil_mode
         return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode)
 
-    def _baddbmm(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        a = self.env[node.args[1]]
-        b = self.env[node.args[2]]
-        alpha = node.kwargs.get("alpha", 1)
-        beta = node.kwargs.get("beta", 1)
-
-        res = None
-        if alpha != 0:
-            res = self.block_builder.emit(relax.op.matmul(a, b))
-            if alpha != 1:
-                dtype = res.struct_info.dtype
-                res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype)))
-        if beta != 0:
-            dtype = x.struct_info.dtype
-            if beta != 1:
-                bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype)))
-            else:
-                bias = x
-            res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
-        return res
-
     def _batch_norm_2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -212,63 +139,13 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
 
-    def _conv1d_transpose_impl(
-        self,
-        x: relax.Expr,
-        weight: relax.Expr,
-        bias: Optional[relax.Expr],
-        strides: Optional[Tuple],
-        padding: Optional[Tuple],
-        dilation: Optional[Tuple],
-        groups: Optional[Tuple],
-    ) -> relax.Var:
-        conv1d_transpose = self.block_builder.emit(
-            relax.op.nn.conv1d_transpose(
-                x,
-                weight,
-                strides=strides,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
-                data_layout="NCW",
-                kernel_layout="OIW",
-                out_dtype="float32",
-            )
-        )
-
-        if bias is None:
-            return conv1d_transpose
-
-        assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1))
-        return self.block_builder.emit(relax.op.add(conv1d_transpose, bias))
-
-    def _conv1d_transpose(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        stride = args[3] if len(args) > 3 else 1
-        padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
-        groups = args[6] if len(args) > 6 else 1
-        return self._conv1d_transpose_impl(
-            x,
-            weight,
-            bias=bias,
-            strides=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-
-    def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var:
+    def _conv_transpose1d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         bias = self.params.get(module.bias, None)
 
-        return self._conv1d_transpose_impl(
+        return self._conv_transpose1d_impl(
             x,
             weight,
             bias=bias,
@@ -278,63 +155,13 @@ class TorchFXImporter(BaseFXGraphImporter):
             groups=module.groups,
         )
 
-    def _conv2d_transpose_impl(
-        self,
-        x: relax.Expr,
-        weight: relax.Expr,
-        bias: Optional[relax.Expr],
-        strides: Optional[Tuple],
-        padding: Optional[Tuple],
-        dilation: Optional[Tuple],
-        groups: Optional[Tuple],
-    ) -> relax.Var:
-        conv2d_transpose = self.block_builder.emit(
-            relax.op.nn.conv2d_transpose(
-                x,
-                weight,
-                strides=strides,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
-                data_layout="NCHW",
-                kernel_layout="OIHW",
-                out_dtype="float32",
-            )
-        )
-
-        if bias is None:
-            return conv2d_transpose
-
-        assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1, 1))
-        return self.block_builder.emit(relax.op.add(conv2d_transpose, bias))
-
-    def _conv2d_transpose(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        stride = args[3] if len(args) > 3 else 1
-        padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
-        groups = args[6] if len(args) > 6 else 1
-        return self._conv2d_transpose_impl(
-            x,
-            weight,
-            bias=bias,
-            strides=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-
-    def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var:
+    def _conv_transpose2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
         weight = self.params[module.weight]
         bias = self.params.get(module.bias, None)
 
-        return self._conv2d_transpose_impl(
+        return self._conv_transpose2d_impl(
             x,
             weight,
             bias=bias,
@@ -344,55 +171,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             groups=module.groups,
         )
 
-    def _conv1d_impl(
-        self,
-        x: relax.Expr,
-        weight: relax.Expr,
-        bias: Optional[relax.Expr],
-        strides: Optional[Tuple],
-        padding: Optional[Tuple],
-        dilation: Optional[Tuple],
-        groups: Optional[Tuple],
-    ) -> relax.Var:
-        conv1d = self.block_builder.emit(
-            relax.op.nn.conv1d(
-                x,
-                weight,
-                strides=strides,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
-                data_layout="NCW",
-                kernel_layout="OIW",
-                out_dtype="float32",
-            )
-        )
-
-        if bias is None:
-            return conv1d
-        assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1))
-        return self.block_builder.emit(relax.op.add(conv1d, bias))
-
-    def _conv1d(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        stride = args[3] if len(args) > 3 else 1
-        padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
-        groups = args[6] if len(args) > 6 else 1
-        return self._conv1d_impl(
-            x,
-            weight,
-            bias=bias,
-            strides=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-
     def _conv1d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -425,55 +203,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             groups=module.groups,
         )
 
-    def _conv3d_impl(
-        self,
-        x: relax.Expr,
-        weight: relax.Expr,
-        bias: Optional[relax.Expr],
-        strides: Optional[Tuple],
-        padding: Optional[Tuple],
-        dilation: Optional[Tuple],
-        groups: Optional[Tuple],
-    ):
-        conv3d = self.block_builder.emit(
-            relax.op.nn.conv3d(
-                x,
-                weight,
-                strides=strides,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
-                data_layout="NCDHW",
-                kernel_layout="OIDHW",
-                out_dtype="float32",
-            )
-        )
-
-        if bias is None:
-            return conv3d
-        assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1, 1, 1))
-        return self.block_builder.emit(relax.op.add(conv3d, bias))
-
-    def _conv3d(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        stride = args[3] if len(args) > 3 else 1
-        padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
-        groups = args[6] if len(args) > 6 else 1
-        return self._conv3d_impl(
-            x,
-            weight,
-            bias=bias,
-            strides=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-
     def _conv3d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -524,30 +253,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             )
         )
 
-    def _einsum(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
-        return self.block_builder.emit(relax.op.einsum(operands, args[0]))
-
-    def _embedding_impl(
-        self,
-        x,
-        weight,
-    ) -> relax.Var:
-        x = self.block_builder.emit(relax.op.astype(x, "int32"))
-
-        ndim = x.struct_info.ndim
-        if ndim == 1:
-            return self.block_builder.emit(relax.op.take(weight, x, axis=0))
-        else:
-            x_shape = x.struct_info.shape.values
-            emb_size = weight.struct_info.shape.values[-1]
-            x = self.block_builder.emit(relax.op.reshape(x, shape=[-1]))
-            embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0))
-            return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
-
     def _embedding_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -655,61 +360,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             )
         )
 
-    def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
-        from torch.fx.immutable_collections import immutable_list
-        import numpy as np  # type: ignore
-
-        if isinstance(normalized_shape, (immutable_list, tuple)):
-            normalized_shape = tuple(normalized_shape)
-        else:
-            try:
-                normalized_shape = self.env[normalized_shape]
-            except TypeError:
-                normalized_shape = tuple(normalized_shape)
-
-        dim_num = len(normalized_shape)
-        axes = list(range(-dim_num, 0))
-
-        if gamma is None:
-            shape_tuple = [int(s) for s in normalized_shape]
-            gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype)
-        if beta is None:
-            shape_tuple = [int(s) for s in normalized_shape]
-            beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype)
-
-        return self.block_builder.emit(
-            relax.op.nn.layer_norm(
-                x,
-                gamma,
-                beta,
-                axes=axes,
-                epsilon=eps,
-            )
-        )
-
-    def _layer_norm(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        normalized_shape = node.args[1]
-        gamma = self.env[node.args[2]] if len(node.args) > 2 else None
-        beta = self.env[node.args[3]] if len(node.args) > 3 else None
-        eps = node.args[4] if len(node.args) > 4 else 1e-05
-        return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
-
-    def _layer_norm_module(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        x = self.env[node.args[0]]
-        module = self.named_modules[node.target]
-        normalized_shape = module.normalized_shape
-        if module.elementwise_affine:
-            gamma = self.params[module.weight]
-            beta = self.params[module.bias]
-        else:
-            gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype)
-            beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype)
-        eps = module.eps
-        return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
-
     def _linear_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -728,39 +378,6 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
-    def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
-        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
-        query = transpose_S_H(self.env[node.args[0]])
-        key = transpose_S_H(self.env[node.args[1]])
-        value = transpose_S_H(self.env[node.args[2]])
-        attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
-        dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
-        assert dropout_p == 0.0, "Dropout is not supported"
-        is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False)
-        causal_mask = "TopLeft" if is_causal else None
-
-        if attn_mask is not None:
-            attn_mask = self.env[attn_mask]
-            msg = "Only a float mask is supported for the attn_mask input."
-            assert "float" in attn_mask.struct_info.dtype, msg
-
-        return self.block_builder.emit(
-            transpose_S_H(
-                relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask)
-            )
-        )
-
-    def _unbind(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
-        selections = self.shape_of(x)[dim].value
-        n_section = list(range(1, selections + 1))
-        ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
-        for i in range(selections):
-            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
-        return self.block_builder.emit(relax.Tuple(ret))
-
     ########## Manipulation ##########
 
     def _cat(self, node: fx.Node) -> relax.Var:
@@ -1054,87 +671,6 @@ class TorchFXImporter(BaseFXGraphImporter):
                 return self.shape_of(self.env[node.args[0]])
         return getattr(self.env[node.args[0]], node.args[1])
 
-    def _getitem(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        x = self.env[node.args[0]]
-        if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
-            return x[node.args[1]]
-        elif isinstance(x, relax.Var):
-            if isinstance(x.struct_info, relax.TupleStructInfo):
-                return self.block_builder.emit(relax.TupleGetItem(x, node.args[1]))
-
-            assert isinstance(x.struct_info, relax.TensorStructInfo)
-            take_indices = []
-            take_axes = []
-            stride_begin = []
-            stride_end = []
-            stride = []
-            stride_axes = []
-            expand_dim = []
-            i = 0
-            shape = self.shape_of(x)
-            non_ellipsis_cnt = 0
-            for index in node.args[1]:
-                if isinstance(index, (int, slice, torch.fx.Node)):
-                    non_ellipsis_cnt += 1
-            for index in node.args[1]:
-                if isinstance(index, int):
-                    stride_begin.append(index)
-                    stride_end.append(index + 1)
-                    stride.append(1)
-                    stride_axes.append(i)
-                    i = i + 1
-                elif isinstance(index, slice):
-                    stride_begin.append(0 if index.start is None else index.start)
-                    stride_end.append(shape[i] if index.stop is None else index.stop)
-                    stride.append(1 if index.step is None else index.step)
-                    stride_axes.append(i)
-                    i = i + 1
-                elif index is None:
-                    expand_dim.append(len(stride_axes) + len(expand_dim))
-                elif index is Ellipsis:
-                    for _ in range(len(shape) - non_ellipsis_cnt):
-                        stride_begin.append(0)
-                        stride_end.append(shape[i])
-                        stride.append(1)
-                        stride_axes.append(i)
-                        i += 1
-                elif isinstance(index, torch.fx.Node):
-                    node_index = self.env[index]
-                    if not isinstance(node_index, relax.Expr):
-                        raise ValueError(
-                            "Unsupported index type for relax.op.take: " + str(type(node_index))
-                        )
-                    take_indices.append(node_index)
-                    take_axes.append(i)
-                    i = i + 1
-                else:
-                    raise ValueError("Unsupported index type: " + str(type(index)))
-            while i < len(shape):
-                stride_begin.append(0)
-                stride_end.append(shape[i])
-                stride.append(1)
-                stride_axes.append(i)
-                i += 1
-            taken = x
-            if len(take_indices) > 1:
-                raise ValueError("Multiple tensors as index not yet supported")
-            for each_index, each_axis in zip(take_indices, take_axes):
-                taken = self.block_builder.emit(relax.op.take(taken, each_index, each_axis))
-            sliced = self.block_builder.emit(
-                relax.op.strided_slice(taken, stride_axes, stride_begin, stride_end, stride)
-            )
-            sliced_shape = list(self.shape_of(sliced))
-            for i in expand_dim:
-                sliced_shape.insert(i, 1)
-            return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
-        elif isinstance(x, relax.Constant):
-            dtype = x.struct_info.dtype
-            return relax.const(x.data.numpy()[node.args[1]], dtype)
-        else:
-            assert False
-
     def _sym_size_int(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -1182,8 +718,8 @@ class TorchFXImporter(BaseFXGraphImporter):
             nn.Conv1d: self._conv1d_module,
             nn.Conv2d: self._conv2d_module,
             nn.Conv3d: self._conv3d_module,
-            nn.ConvTranspose1d: self._conv1d_transpose_module,
-            nn.ConvTranspose2d: self._conv2d_transpose_module,
+            nn.ConvTranspose1d: self._conv_transpose1d_module,
+            nn.ConvTranspose2d: self._conv_transpose2d_module,
             nn.CrossEntropyLoss: self._cross_entropy_module,
             nn.GroupNorm: self._group_norm_module,
             nn.LayerNorm: self._layer_norm_module,
@@ -1248,8 +784,8 @@ class TorchFXImporter(BaseFXGraphImporter):
             "bmm": self._binary_op(
                 partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul
             ),
-            "conv_transpose1d": self._conv1d_transpose,
-            "conv_transpose2d": self._conv2d_transpose,
+            "conv_transpose1d": self._conv_transpose1d,
+            "conv_transpose2d": self._conv_transpose2d,
             "conv1d": self._conv1d,
             "conv2d": self._conv2d,
             "conv3d": self._conv3d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 25e6dbfae3..7c887d9b96 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1156,6 +1156,59 @@ def test_binary():
     verify_model(Sub2(), example_args2, {}, expected_sub2)
 
 
+def test_batchnorm2d():
+    class BatchNorm2d(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = BatchNorm2d().eval()
+    binding = {
+        "w1": model.bn.weight.detach().numpy(),
+        "w2": model.bn.bias.detach().numpy(),
+        "w3": model.bn.running_mean.detach().numpy(),
+        "w4": model.bn.running_var.detach().numpy(),
+    }
+    verify_model(model, example_args, binding, expected1)
+
+
 def test_adaptive_avgpool2d():
     class AdaptiveAvgPool2d0(Module):
         def __init__(self):
@@ -1165,28 +1218,594 @@ def test_adaptive_avgpool2d():
         def forward(self, input):
             return self.pool(input)
 
-    class AdaptiveAvgPool2d1(Module):
+    class AdaptiveAvgPool2d1(Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d(
+                    input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+
+
+def test_addmm():
+    class Addmm1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.addmm(x1, x2, x3)
+
+    class Addmm2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x1, x2, x3):
+            return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32"))
+                lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+        torch.randn(10, 10, dtype=torch.float32),
+    )
+
+    verify_model(Addmm1(), example_args, {}, expected1)
+    verify_model(Addmm2(), example_args, {}, expected2)
+
+
+def test_avg_pool2d():
+    class AvgPool2d1(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d(
+                    input_1,
+                    pool_size=[1, 1],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool2d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AvgPool2d3(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool2d(
+                input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool2d(
+                    input_1,
+                    pool_size=[4, 4],
+                    strides=[2, 2],
+                    dilation=[1, 1],
+                    padding=[2, 2, 2, 2],
+                    ceil_mode=True,
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    class AvgPool2d4(Module):
+        def forward(self, input):
+            return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv = R.nn.avg_pool2d(
+                    input_1,
+                    pool_size=[2, 1],
+                    strides=[2, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    ceil_mode=False,
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(AvgPool2d1(), example_args, {}, expected1)
+    verify_model(AvgPool2d2(), example_args, {}, expected2)
+    verify_model(AvgPool2d3(), example_args, {}, expected2)
+    verify_model(AvgPool2d4(), example_args, {}, expected3)
+
+
+def test_baddbmm():
+    class BAddBMM1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0)
+                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    class BAddBMM2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y, alpha=2, beta=0)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+                    lv, R.const(2, "float32")
+                )
+                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    class BAddBMM3(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, c, x, y):
+            return torch.baddbmm(c, x, y, alpha=2, beta=3)
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 128, 512), dtype="float32"),
+            inp_1: R.Tensor((4, 128, 256), dtype="float32"),
+            inp_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+                    lv, R.const(2, "float32")
+                )
+                lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
+                    inp_0, R.const(3, "float32")
+                )
+                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(4, 128, 512, dtype=torch.float32),
+        torch.randn(4, 128, 256, dtype=torch.float32),
+        torch.randn(4, 256, 512, dtype=torch.float32),
+    )
+    verify_model(
+        BAddBMM1(),
+        example_args,
+        {},
+        Expected1,
+    )
+
+    verify_model(
+        BAddBMM2(),
+        example_args,
+        {},
+        Expected2,
+    )
+
+    verify_model(
+        BAddBMM3(),
+        example_args,
+        {},
+        Expected3,
+    )
+
+
+def test_bmm():
+    """torch.bmm on two rank-3 float32 tensors imports as a single R.matmul."""
+
+    class BMM(Module):
+        # No custom __init__ needed: nn.Module's constructor is sufficient.
+        def forward(self, x, y):
+            return torch.bmm(x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((4, 128, 256), dtype="float32"),
+            input_2: R.Tensor((4, 256, 512), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    lhs = torch.randn(4, 128, 256, dtype=torch.float32)
+    rhs = torch.randn(4, 256, 512, dtype=torch.float32)
+    verify_model(BMM(), (lhs, rhs), {}, Expected)
+
+def test_conv_transpose1d():
+    """ConvTranspose1d (module and functional, with/without bias) imports correctly."""
+
+    class ConvTranspose1d1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    class ConvTranspose1d1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 6, 3])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias)
+
+    # With bias: conv1d_transpose, then the (6,) bias reshaped to (1, 6, 1) and
+    # added so it broadcasts over batch and width.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 6, 4), dtype="float32"),
+            w1: R.Tensor((6, 6, 3), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0],
+                    dilation=[1],
+                    data_layout="NCW",
+                    kernel_layout="OIW",
+                    out_layout="NCW",
+                    out_dtype="float32",
+                )
+                # NOTE(review): unlike the bias reshape in test_conv1d, this
+                # annotation omits dtype — confirm that is intentional.
+                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+                lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class ConvTranspose1d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    # Without bias: the conv1d_transpose result is returned directly.
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 6, 4), dtype="float32"),
+            w1: R.Tensor((6, 6, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0],
+                    dilation=[1],
+                    data_layout="NCW",
+                    kernel_layout="OIW",
+                    out_layout="NCW",
+                    out_dtype="float32",
+                )
+                gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 6, 4, dtype=torch.float32),)
+
+    # Keys in `binding` are the weight parameter names of the expected `main`.
+    model = ConvTranspose1d1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = ConvTranspose1d1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = ConvTranspose1d2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+def test_conv_transpose2d():
+    """ConvTranspose2d (module and functional, with/without bias) imports correctly."""
+
+    class ConvTranspose2d1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    class ConvTranspose2d1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[3, 3, 7, 7])
+            self.bias = torch.randn(size=[3])
+
+        def forward(self, input):
+            return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias)
+
+    # With bias: conv2d_transpose, then the (3,) bias reshaped to (1, 3, 1, 1)
+    # and added so it broadcasts over batch and both spatial axes.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+                lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class ConvTranspose2d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    # Without bias: the conv2d_transpose result is returned directly.
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3, 3, 7, 7), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    # Keys in `binding` are the weight parameter names of the expected `main`.
+    model = ConvTranspose2d1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = ConvTranspose2d1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = ConvTranspose2d2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+
+def test_conv1d():
+    """Conv1d (module and functional, with/without bias) imports as R.nn.conv1d."""
+
+    class Conv1D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    class Conv1D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv1d(input, self.weight, self.bias)
+
+    # With bias: conv1d, then the (6,) bias reshaped to (1, 6, 1) and added so
+    # it broadcasts over batch and width.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            w1: R.Tensor((6, 3, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+            input_1: R.Tensor((1, 3, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0],
+                    dilation=[1],
+                    data_layout="NCW",
+                    kernel_layout="OIW",
+                    out_layout="NCW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
+                lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class Conv1D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv1d(3, 6, 7, bias=False)
+
         def forward(self, input):
-            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+            return self.conv(input)
 
     @tvm.script.ir_module
-    class expected1:
+    class expected2:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
-        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            w1: R.Tensor((6, 3, 7), dtype="float32"),
+            input_1: R.Tensor((1, 3, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d(
-                    input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW"
+                lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d(
+                    input_1,
+                    w1,
+                    strides=[1],
+                    padding=[0, 0],
+                    dilation=[1],
+                    data_layout="NCW",
+                    kernel_layout="OIW",
+                    out_layout="NCW",
+                    out_dtype="float32",
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
-    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
-    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+    example_args = (torch.randn(1, 3, 10, dtype=torch.float32),)
+
+    # expected1 covers both bias-carrying variants; the bias-free module uses
+    # expected2. Keys in `binding` are the weight parameter names of `main`.
+    model = Conv1D1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv1D1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv1D2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
 
 
 def test_conv2d():
@@ -1281,6 +1900,267 @@ def test_conv2d():
     verify_model(model, example_args, binding, expected2)
 
 
+def test_conv3d():
+    """Conv3d (module and functional, with/without bias) imports as R.nn.conv3d."""
+
+    class Conv3D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv3d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    class Conv3D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7, 7, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv3d(input, self.weight, self.bias)
+
+    # With bias: conv3d, then the (6,) bias reshaped to (1, 6, 1, 1, 1) and
+    # added so it broadcasts over batch and all three spatial axes.
+    # strides/dilation list one entry per spatial dim (D, H, W), matching the
+    # three-entry padding.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+                    input_1,
+                    w1,
+                    strides=[1, 1, 1],
+                    padding=[0, 0, 0],
+                    dilation=[1, 1, 1],
+                    data_layout="NCDHW",
+                    kernel_layout="OIDHW",
+                    out_layout="NCDHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class Conv3D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv3d(3, 6, 7, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    # Without bias: the conv3d result is returned directly.
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d(
+                    input_1,
+                    w1,
+                    strides=[1, 1, 1],
+                    padding=[0, 0, 0],
+                    dilation=[1, 1, 1],
+                    data_layout="NCDHW",
+                    kernel_layout="OIDHW",
+                    out_layout="NCDHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
+
+    # Keys in `binding` are the weight parameter names of the expected `main`.
+    model = Conv3D1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv3D1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv3D2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+
+def test_einsum():
+    """torch.einsum imports as R.einsum with the subscripts string passed through."""
+
+    class Einsum1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.einsum("ii", x)
+
+    class Einsum2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.einsum("i,j->ij", x, y)
+
+    # "ii" on a square (4, 4) input contracts the diagonal into a 0-d result.
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((4, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # "i,j->ij" on (5,) and (4,) inputs produces their (5, 4) outer product.
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
+        ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+                    (inp_0, inp_1), subscripts="i,j->ij"
+                )
+                gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(4, 4, dtype=torch.float32),)
+    verify_model(Einsum1(), example_args, {}, Expected1)
+
+    example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32))
+    verify_model(Einsum2(), example_args, {}, Expected2)
+
+
+def test_embedding():
+    """nn.Embedding imports as an int32 cast of the indices followed by R.take."""
+
+    class Embedding(Module):
+        def __init__(self):
+            super().__init__()
+            self.embedding = torch.nn.Embedding(10, 3)
+
+        def forward(self, input):
+            return self.embedding(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32")
+                lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
+                gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    # Indices must be valid rows of the (10, 3) table: nn.Embedding rejects
+    # indices outside [0, num_embeddings) when the model is actually run.
+    example_args = (torch.randint(low=0, high=10, size=(4,), dtype=torch.int64),)
+
+    model = Embedding()
+    binding = {"w1": model.embedding.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+
+def test_groupnorm():
+    """nn.GroupNorm(3, 3) imports as R.nn.group_norm over axes 2 and 3."""
+
+    class GroupNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.gn = torch.nn.GroupNorm(3, 3)
+
+        def forward(self, input):
+            return self.gn(input)
+
+    # Three groups along channel axis 1; weight (w1) and bias (w2) are both
+    # applied (scale=True, center=True).
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    num_groups=3,
+                    channel_axis=1,
+                    axes=[2, 3],
+                    epsilon=1.0000000000000001e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = GroupNorm()
+    binding = {
+        "w1": model.gn.weight.detach().numpy(),
+        "w2": model.gn.bias.detach().numpy(),
+    }
+    verify_model(model, example_args, binding, expected1)
+
+
+def test_layernorm():
+    """nn.LayerNorm((10, 10)) imports as R.nn.layer_norm over the last two axes."""
+
+    class LayerNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = torch.nn.LayerNorm((10, 10))
+
+        def forward(self, input):
+            return self.ln(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((10, 10), dtype="float32"),
+            w2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    axes=[-2, -1],
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = LayerNorm()
+    binding = {
+        "w1": model.ln.weight.detach().numpy(),
+        "w2": model.ln.bias.detach().numpy(),
+    }
+    # Export the same instance the bound weights were read from.
+    verify_model(model, example_args, binding, expected1)
+
+
 def test_linear():
     class Dense1(Module):
         def __init__(self):
@@ -1460,6 +2340,254 @@ def test_maxpool2d():
     verify_model(MaxPool2d3(), example_args, {}, expected3)
 
 
+def test_scaled_dot_product_attention():
+    """F.scaled_dot_product_attention imports as R.nn.attention, with and without a mask."""
+
+    class Attention1(Module):
+        def forward(self, q, k, v):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+    # q/k/v have axes 1 and 2 swapped before R.nn.attention, and the result is
+    # swapped back. NOTE(review): presumably R.nn.attention expects
+    # (batch, seq, heads, dim) — confirm against the op's documentation.
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, scale=None
+                )
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    class Attention2(Module):
+        def forward(self, q, k, v, mask):
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask)
+
+    # Same as Expected1, but the mask (inp_3) is forwarded to R.nn.attention
+    # without being permuted.
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_0, axes=[0, 2, 1, 3]
+                )
+                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_1, axes=[0, 2, 1, 3]
+                )
+                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
+                    inp_2, axes=[0, 2, 1, 3]
+                )
+                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
+                    lv, lv1, lv2, inp_3, scale=None
+                )
+                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv3, axes=[0, 2, 1, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                R.output(gv)
+            return gv
+
+    verify_model(
+        Attention1(),
+        (
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+        ),
+        {},
+        Expected1,
+    )
+
+    verify_model(
+        Attention2(),
+        (
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 64, dtype=torch.float32),
+            torch.randn(32, 8, 128, 128, dtype=torch.float32),
+        ),
+        {},
+        Expected2,
+    )
+
+
+def test_unbind():
+    """torch.unbind imports as split + per-chunk squeeze, for dim=0 and dim=1."""
+
+    class Unbind1(Module):
+        def forward(self, data):
+            return torch.unbind(data)
+
+    # Default dim=0: R.split at indices [1, 2, 3] yields four chunks, the last
+    # of size 0 along axis 0 and unused. Each of the first three chunks is
+    # squeezed on axis 0, and the results are repacked into the output tuple.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((0, 3, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0])
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    class Unbind2(Module):
+        def forward(self, data):
+            return torch.unbind(data, dim=1)
+
+    # dim=1: same lowering as expected1, with split and squeeze on axis 1.
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+            R.Tensor((3, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 1, 10, 10), dtype="float32"),
+                    R.Tensor((3, 0, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
+                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
+                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1])
+                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
+                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1])
+                lv7: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv2, lv4, lv6)
+                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
+                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
+                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                gv: R.Tuple(
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                    R.Tensor((3, 10, 10), dtype="float32"),
+                ) = (lv8, lv9, lv10)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Unbind1(), example_args, {}, expected1)
+    verify_model(Unbind2(), example_args, {}, expected2)
+
+
+def test_interpolate():
+    """F.interpolate to a fixed (224, 224) size imports as R.image.resize2d."""
+
+    class InterpolateBilinear(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear")
+
+    # mode="bilinear" maps to method="linear" with half_pixel coordinates.
+    @tvm.script.ir_module
+    class expected_bilinear:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 112, 112), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
+                    input,
+                    R.shape([224, 224]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class InterpolateNearest(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (224, 224), mode="nearest")
+
+    # mode="nearest" maps to method="nearest_neighbor"; other attributes match
+    # the bilinear case.
+    @tvm.script.ir_module
+    class expected_nearest:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 112, 112), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d(
+                    input,
+                    R.shape([224, 224]),
+                    roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)],
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0.0,
+                    out_dtype="void",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),)
+    verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear)
+    verify_model(InterpolateNearest(), example_args, {}, expected_nearest)
+
+
 def test_mean():
     class Mean(Module):
         def forward(self, input):


Reply via email to