This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 877fbf1867 [Unity] Initial PyTorch Frontend (#14037)
877fbf1867 is described below

commit 877fbf186718a241dbb1730eacec3c0b8496765d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Feb 18 11:19:48 2023 -0500

    [Unity] Initial PyTorch Frontend (#14037)
    
    [Unity] Initial PyTorch Frontend
    
    This PR introduces initial pytorch frontend components of Relax, including
    - a FX translator that translates a Torch FX graph module to an TVM 
IRModule,
    - a Relax-backend of Torch Dynamo, which brings the mechanism to build 
PyTorch model using Relax compilation pipeline,
    - a pipeline prototype that contains the collection of pre-defined 
pipelines that optimizes and lower IRModule before passing to minimum build.
    
    Co-authored-by: Bohan Hou 
<[email protected]>
    Co-authored-by: Tianqi Chen <[email protected]>
    Co-authored-by: Siyuan Feng <[email protected]>
---
 python/tvm/relax/__init__.py                     |    3 +
 python/tvm/relax/frontend/__init__.py            |   19 +
 python/tvm/relax/frontend/torch/__init__.py      |   21 +
 python/tvm/relax/frontend/torch/dynamo.py        |  156 ++
 python/tvm/relax/frontend/torch/fx_translator.py |  820 ++++++++++
 python/tvm/relax/pipeline.py                     |   84 ++
 tests/python/relax/test_frontend_dynamo.py       |  198 +++
 tests/python/relax/test_frontend_from_fx.py      | 1729 ++++++++++++++++++++++
 tests/python/relax/test_pipeline.py              |   45 +
 9 files changed, 3075 insertions(+)

diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index cfcf7876dc..33a9c2eece 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -73,6 +73,9 @@ from .struct_info import (
     FuncStructInfo,
 )
 
+# pipeline
+from .pipeline import get_pipeline
+
 # Import submodules in the last to avoid dependency
 from . import exec_builder
 from . import expr
diff --git a/python/tvm/relax/frontend/__init__.py 
b/python/tvm/relax/frontend/__init__.py
new file mode 100644
index 0000000000..6c9c188aaa
--- /dev/null
+++ b/python/tvm/relax/frontend/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Frontends for constructing Relax programs, with the model importers
+"""
diff --git a/python/tvm/relax/frontend/torch/__init__.py 
b/python/tvm/relax/frontend/torch/__init__.py
new file mode 100644
index 0000000000..55da5a456d
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+PyTorch Frontends for constructing Relax programs, with the model importers
+"""
+from .fx_translator import from_fx
+from .dynamo import relax_dynamo, dynamo_capture_subgraphs
diff --git a/python/tvm/relax/frontend/torch/dynamo.py 
b/python/tvm/relax/frontend/torch/dynamo.py
new file mode 100644
index 0000000000..94de73a431
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, missing-function-docstring, not-callable
+# pylint: disable=import-outside-toplevel, unused-argument
+# mypy: ignore-errors
+"""PyTorch Dynamo backend of Relax."""
+import functools
+from typing import Optional
+
+import tvm
+from tvm.relax.vm import build as relax_build
+from tvm.relax.frontend.torch.fx_translator import from_fx
+
+
+def device_from_inputs(example_inputs):
+    for x in example_inputs:
+        if hasattr(x, "device"):
+            return x.device
+    return None
+
+
+def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
+    """A helper function to create a relax backend.
+
+    Parameters
+    ----------
+    pipeline : Optional[tvm.transform.Pass]
+        The pipeline to be applied to the relax module before sent to build.
+
+    Returns
+    -------
+    backend : Callable[[torch.fx.GraphModule, List[torch.Tensor]], Callable]
+        The relax dynamo backend.
+    """
+
+    def _relax_backend(graph_module, example_inputs):
+        import torch  # type: ignore[import]
+
+        assert isinstance(graph_module, torch.fx.GraphModule)
+
+        def to_torch_tensor(nd_tensor):
+            """A helper function to transfer a NDArray to torch.tensor."""
+            if isinstance(nd_tensor, tvm.nd.NDArray):
+                return torch.from_numpy(nd_tensor.numpy())
+            elif isinstance(nd_tensor, tvm.ir.Array):
+                return tuple(to_torch_tensor(x) for x in nd_tensor)
+            else:
+                raise ValueError(f"Unsupported type {type(nd_tensor)}")
+
+        def to_tvm_tensor(torch_tensor):
+            """A helper function to transfer a torch.tensor to NDArray."""
+            if not isinstance(torch_tensor, 
torch._subclasses.fake_tensor.FakeTensor):
+                return tvm.nd.array(torch_tensor.numpy())
+            # Fake Tensor
+            real_tensor = torch.randn(torch_tensor.shape, 
dtype=torch_tensor.dtype)
+            return tvm.nd.array(real_tensor.numpy())
+
+        device = device_from_inputs(example_inputs)
+        input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in 
example_inputs]
+        mod = from_fx(graph_module, input_info)
+
+        if device.type == "cuda":
+            dev = tvm.cuda(device.index)
+            target = tvm.target.cuda()
+        else:
+            dev = tvm.cpu(0)
+            target = tvm.target.Target(llvm_target())
+
+        # invoke optimization pipeline.
+        if pipeline is None:
+            # get default pipeline
+            seq = tvm.relax.get_pipeline()
+        elif isinstance(pipeline, str):
+            # lookup by name
+            seq = tvm.relax.get_pipeline(pipeline)
+        else:
+            seq = pipeline
+
+        mod = mod.with_attr("target", target)
+        mod = seq(mod)
+
+        ex = relax_build(mod, target=target)
+
+        vm = tvm.relax.vm.VirtualMachine(exec=ex.mod, device=dev)
+
+        def exec_tvm(*i_args):
+            args = [a.contiguous() for a in i_args]
+            vm_args = list()
+            for arg in args:
+                if arg.dim() != 0:
+                    if arg.requires_grad:
+                        arg = arg.detach()
+                    vm_args.append(to_tvm_tensor(arg))
+            outputs = vm["main"](*vm_args)
+            return to_torch_tensor(outputs)
+
+        return exec_tvm
+
+    return _relax_backend
+
+
+def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
+    """Capture subgraphs of the PyTorch model using torch.compile into an 
IRModule.
+
+    Parameters
+    ----------
+    model : torch.nn.Module
+        The PyTorch model to be captured.
+
+    params : List[torch.Tensor]
+        The parameters of the PyTorch model.
+
+    Returns
+    -------
+    mod : tvm.ir.IRModule
+        The IRModule that contains captured subgraphs.
+    """
+    import torch  # type: ignore[import]
+    from torch import fx  # type: ignore[import]
+    from torch import _dynamo as dynamo  # type: ignore[import]
+
+    mod = tvm.IRModule()
+
+    def _capture(graph_module: fx.GraphModule, example_inputs):
+        assert isinstance(graph_module, torch.fx.GraphModule)
+        input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in 
example_inputs]
+        subgraph = from_fx(graph_module, input_info)
+        mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"]
+        return graph_module.forward
+
+    dynamo.reset()
+    compiled_model = torch.compile(model, backend=_capture)
+    compiled_model(*params)
+    return mod
+
+
[email protected]_cache(None)
+def llvm_target():
+    if "avx512" in open("/proc/cpuinfo").read():
+        return "llvm -mcpu=skylake-avx512"
+    return "llvm -mcpu=core-avx2"
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
new file mode 100644
index 0000000000..582f2edbcf
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -0,0 +1,820 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements, 
unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""PyTorch FX frontend of Relax."""
+from typing import Callable, Dict, List, Tuple, Union
+from functools import reduce
+
+import tvm
+from tvm import relax
+
+
+class TorchFXImporter:
+    """An importer from PyTorch FX to Relax."""
+
+    import torch  # type: ignore
+    from torch import fx
+
+    def __init__(self) -> None:
+        import torch  # type: ignore
+        from torch import fx
+
+        self.env: Dict[fx.node.Node, relax.Expr] = {}
+        self.params: Dict[torch.Tensor, relax.Constant] = {}
+        self.named_modules: Dict[str, torch.Module] = None
+        self.block_builder: relax.BlockBuilder = None
+        self.create_convert_map()
+
+    ########## Utilities ##########
+    @staticmethod
+    def _fetch_attr(model, target: str):
+        import torch  # type: ignore
+
+        target_atoms = target.split(".")
+        attr_itr = model
+        for i, atom in enumerate(target_atoms):
+            if not hasattr(attr_itr, atom):
+                raise RuntimeError(
+                    f"Node referenced non existing target 
{'.'.join(target_atoms[:i])}"
+                )
+            attr_itr = getattr(attr_itr, atom)
+        if isinstance(attr_itr, torch.Tensor):
+            return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
+        return attr_itr
+
+    @staticmethod
+    def _convert_data_type(input_type):
+        """converts the PyTorch scalar type input_type to a TVM dtype."""
+        import torch  # type: ignore
+
+        input_type = input_type.lower()
+        if input_type in ["float", "float32", "torch.float32", torch.float32]:
+            return "float32"
+        elif input_type in ["float16", "torch.float16", torch.float16]:
+            return "float16"
+        elif input_type in ["int64", "torch.int64", torch.int64]:
+            return "int64"
+        else:
+            raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
+
+    @staticmethod
+    def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
+        tensor = tensor.detach().cpu()
+        shape = tensor.data.shape
+        dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
+        return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, 
dtype))
+
+    @staticmethod
+    def shape_of(tensor):
+        """Get the shape of a tensor."""
+        import torch  # type: ignore
+
+        if isinstance(tensor, relax.Expr):
+            if not isinstance(tensor.struct_info, relax.TensorStructInfo):
+                raise TypeError("The input Expr of shape_of should be a 
Tensor")
+            return tensor.struct_info.shape
+        elif isinstance(tensor, torch.Tensor):
+            return tensor.shape
+        raise ValueError("Unsupported type: {}".format(type(tensor)))
+
+    def retrieve_args(self, node):
+        return self._retrieve_args(node.args)
+
+    def _retrieve_args(self, node):
+        from torch import fx
+
+        if isinstance(node, fx.node.Node):
+            return self.env[node]
+        elif isinstance(node, tuple):
+            return tuple(self._retrieve_args(x) for x in node)
+        elif isinstance(node, list):
+            return [self._retrieve_args(x) for x in node]
+        elif isinstance(node, dict):
+            return {self._retrieve_args(k): self._retrieve_args(v) for k, v in 
node.items()}
+        else:
+            return node
+
+    @staticmethod
+    def _promote_binary_op_args(lhs, rhs):
+        if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+            return lhs, rhs
+        elif isinstance(lhs, relax.Expr):
+            assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+            return lhs, relax.const(rhs, lhs.struct_info.dtype)
+        elif isinstance(rhs, relax.Expr):
+            assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+            return relax.const(lhs, rhs.struct_info.dtype), rhs
+        else:
+            assert False
+
+    def _call_binary_op(self, op, lhs, rhs):
+        lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
+        return self.block_builder.emit(op(lhs, rhs))
+
+    ########## Arithmetic ##########
+
+    def _cos(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
+
+    def _sin(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
+
+    def _sqrt(self, node: fx.node.Node) -> relax.Expr:
+        arg = self.env[node.args[0]]
+        if isinstance(arg, (int, float)):
+            arg = relax.const(arg, "float32")
+        return self.block_builder.emit(relax.op.sqrt(arg))
+
+    def _add(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.add, lhs, rhs)
+        return lhs + rhs
+
+    def _floordiv(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
+        return lhs // rhs
+
+    def _mul(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.multiply, lhs, rhs)
+        return lhs * rhs
+
+    def _sub(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.subtract, lhs, rhs)
+        return lhs - rhs
+
+    def _truediv(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+            return self._call_binary_op(relax.op.divide, lhs, rhs)
+        return lhs / rhs
+
+    def _clamp(self, node: fx.node.Node) -> relax.Expr:
+        args = self.retrieve_args(node)
+        a_min = node.kwargs["min"]
+        a_max = node.kwargs["max"]
+        if not isinstance(a_min, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant min value for torch.clamp/clip, "
+                f"but got {a_min} with type {type(a_min)}"
+            )
+        if not isinstance(a_max, (int, float)):
+            raise ValueError(
+                f"TVM only supports constant max value for torch.clamp/clip, "
+                f"but got {a_max} with type {type(a_max)}"
+            )
+        return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+    ########## Compare ##########
+
+    def _lt(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.less, lhs, rhs)
+
+    ########## Creation ##########
+
+    def _tril(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        k = node.args[1] if len(node.args) > 1 else 0
+        assert isinstance(k, int)
+        return self.block_builder.emit(relax.op.create.tril(x, k))
+
+    def _new_ones(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
+            )
+        )
+
+    ########## Statistical ##########
+
+    def _sum(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        if len(args) == 1:
+            return self.block_builder.emit(relax.op.sum(args[0]))
+        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+
+    ########## DataType ##########
+
+    def _float(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float32"))
+
+    def _half(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float16"))
+
+    def _type(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+
+    ########## Linear Algebra ##########
+
+    def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
+        return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, 
out_dtype="float32"))
+
+    def _matmul(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        res = self._matmul_impl(
+            args[0],
+            args[1],
+        )
+        return res
+
+    def _addmm(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        y = self.env[node.args[1]]
+        z = self.env[node.args[2]]
+        matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, 
out_dtype="float32"))
+        return self.block_builder.emit(relax.op.add(x, matmul))
+
+    ########## Manipulation ##########
+
+    def _cat(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.concat(args[0], 
axis=node.kwargs["dim"]))
+
+    def _expand(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
args[1:]))
+
+    def _flatten(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        if node.target in self.named_modules:
+            module = self.named_modules[node.target]
+            start_dim = module.start_dim
+            end_dim = module.end_dim
+        else:
+            start_dim = node.args[1] if len(node.args) >= 2 else 0
+            end_dim = node.args[2] if len(node.args) == 3 else -1
+        shape = self.shape_of(x)
+        start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+        end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+        flattened = reduce(lambda x, y: x * y, [shape[i] for i in 
range(start_dim, end_dim + 1)])
+        new_shape = (
+            [shape[i] for i in range(0, start_dim)]
+            + [flattened]
+            + [shape[i] for i in range(end_dim + 1, len(shape))]
+        )
+        return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
+    def _permute(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.permute_dims(args[0], 
args[1:]))
+
+    def _reshape(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.reshape(args[0], 
tuple(args[1])))
+        return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
+
+    def _split(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        split_size = node.args[1]
+        if "dim" in node.kwargs:
+            dim = node.kwargs["dim"]
+        else:
+            dim = 0
+        n_section = (self.shape_of(x)[dim].value + split_size - 1) // 
split_size
+        return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+    def _transpose(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        full_idx = list(range(len(self.shape_of(args[0]))))
+        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], 
full_idx[args[1]]
+        return self.block_builder.emit(relax.op.permute_dims(args[0], 
full_idx))
+
+    ########## Neural Network ##########
+
+    def _linear(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = None if module.bias is None else self.params[module.bias]
+        return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
+
+    def _conv2d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+
+        conv2d = self.block_builder.emit(
+            relax.op.nn.conv2d(
+                x,
+                weight,
+                strides=module.stride,
+                padding=module.padding,
+                dilation=module.dilation,
+                groups=module.groups,
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+                out_dtype="float32",
+            )
+        )
+
+        if module.bias is None:
+            return conv2d
+
+        bias = self.params[module.bias]
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1))
+
+        return self.block_builder.emit(relax.op.add(conv2d, bias))
+
+    def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        if node.target in self.named_modules:
+            module = self.named_modules[node.target]
+            kernel = module.kernel_size
+            stride = module.stride
+            padding = module.padding
+            dilation = module.dilation
+            ceil_mode = module.ceil_mode
+        else:
+            nargs = len(node.args)
+            kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"]
+            stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
+            padding = node.args[3] if nargs > 3 else node.kwargs["padding"]
+            dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"]
+            ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"]
+
+        stride = kernel if stride is None else stride
+
+        return self.block_builder.emit(
+            relax.op.nn.max_pool2d(
+                x,
+                pool_size=kernel,
+                strides=stride,
+                padding=padding,
+                dilation=dilation,
+                layout="NCHW",
+                ceil_mode=ceil_mode,
+            )
+        )
+
+    def _adaptive_avg_pool2d(self, is_module: bool) -> Callable:
+        from torch import fx
+
+        def _impl(node: fx.node.Node) -> relax.Var:
+            if is_module:
+                module = self.named_modules[node.target]
+                x = self.env[node.args[0]]
+                output_size = module.output_size
+            else:
+                x = self.env[node.args[0]]
+                output_size = node.args[1]
+            return self.block_builder.emit(
+                relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
+            )
+
+        return _impl
+
+    def _softmax(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        if node.target in self.named_modules:
+            module = self.named_modules[node.target]
+            dim = module.dim
+        else:
+            nargs = len(node.args)
+            dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
+        assert dim is not None
+        return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+    def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        bias = self.params[module.bias]
+        dtype = self._convert_data_type(str(module.running_mean.dtype))
+        running_mean = relax.const(module.running_mean.cpu().detach().numpy(), 
dtype)
+        running_var = relax.const(module.running_var.cpu().detach().numpy(), 
dtype)
+        eps = module.eps
+
+        res_tuple = self.block_builder.emit(
+            relax.op.nn.batch_norm(
+                x,
+                weight,
+                bias,
+                running_mean,
+                running_var,
+                axis=1,
+                epsilon=eps,
+            )
+        )
+
+        return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+
+    def _layer_norm(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+
+        if module.elementwise_affine:
+            gamma = self.params[module.weight]
+            beta = self.params[module.bias]
+        else:
+            gamma = relax.const(torch.ones_like(module.normalized_shape), 
x.checked_type)
+            beta = relax.const(torch.zeros_like(module.normalized_shape), 
x.checked_type)
+        dim_num = len(module.normalized_shape)
+        axes = list(range(-dim_num, 0))
+
+        return self.block_builder.emit(
+            relax.op.nn.layer_norm(
+                x,
+                gamma,
+                beta,
+                axes=axes,
+                epsilon=module.eps,
+            )
+        )
+
+    def _group_norm(self, node: fx.node.Node) -> relax.Var:
+        # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
+        #                    affine=True, device=None, dtype=None)
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        num_groups = module.num_groups
+        num_channels = module.num_channels
+        eps = module.eps
+        affine = module.affine
+
+        shape = self.shape_of(x)
+        assert len(shape) == 4
+        N, C, H, W = shape[0], shape[1], shape[2], shape[3]
+        assert C == num_channels
+        assert C % num_groups == 0
+        grouped_x = self.block_builder.emit(
+            relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
+        )
+        mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], 
keepdims=True))
+        sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
+        square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
+        sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 
4], keepdims=True))
+        var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // 
num_groups * H * W).value)
+        var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
+        std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
+        norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))
+
+        if affine:
+            weight = self.params[module.weight]
+            bias = self.params[module.bias]
+            weight_reshape = self.block_builder.emit(
+                relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 
1))
+            )
+            bias_reshape = self.block_builder.emit(
+                relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
+            )
+            norm_x = self.block_builder.emit(relax.op.multiply(norm_x, 
weight_reshape))
+            norm_x = self.block_builder.emit(relax.op.add(norm_x, 
bias_reshape))
+        return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
+
+    def _embedding(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+        x = self.block_builder.emit(relax.op.astype(x, "int32"))
+        return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+    def _interpolate(self, node: fx.node.Node) -> relax.Var:
+        # torch.nn.functional.interpolate(
+        #   input, size=None, scale_factor=None, mode='nearest', 
align_corners=None,
+        #   recompute_scale_factor=None, antialias=False)
+        # (TODO) this is a temporary implementation for interpolate that only 
considers NCHW layout
+        # it basically replicates the implementation in 
tvm.relay.frontend.pytorch
+        data = self.env[node.args[0]]
+        size = node.kwargs["size"]
+        scale_factor = node.kwargs["scale_factor"]
+        method = node.kwargs["mode"]
+        align_corners = node.kwargs["align_corners"]
+        recompute_scale_factor = node.kwargs["recompute_scale_factor"]
+        antialias = node.kwargs["antialias"]
+
+        assert recompute_scale_factor is None
+        assert antialias is False
+
+        if size is None:
+            shape = self.shape_of(data)
+            assert isinstance(shape, relax.ShapeExpr)
+            size = tuple(int(shape[i].value * scale_factor) for i in range(2, 
len(shape)))
+
+        if method.startswith("nearest"):
+            method = "nearest_neighbor"
+        elif method[0:2] == "bi":
+            method = method[2:]
+
+        if method == "nearest_neighbor":
+            coord_trans = "asymmetric"
+        elif align_corners:
+            coord_trans = "align_corners"
+        else:
+            coord_trans = "half_pixel"
+
+        return self.block_builder.emit(
+            relax.op.image.resize2d(
+                data, size, layout="NCHW", method=method, 
coordinate_transformation_mode=coord_trans
+            )
+        )
+
+    ########## Others ##########
+
+    def _size(self, node: fx.node.Node) -> relax.Expr:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        if len(node.args) == 1:
+            assert isinstance(shape, relax.ShapeExpr)
+            return shape
+        assert len(node.args) == 2
+        idx = node.args[1]
+        return self.shape_of(x)[idx].value
+
+    def _getattr(self, node: fx.node.Node) -> relax.Var:
+        if isinstance(self.env[node.args[0]], relax.Expr):
+            if node.args[1] == "dtype":
+                return self.env[node.args[0]].struct_info.dtype
+            elif node.args[1] == "shape":
+                return self.shape_of(self.env[node.args[0]])
+        return getattr(self.env[node.args[0]], node.args[1])
+
+    def _getitem(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
+            return x[node.args[1]]
+        elif isinstance(x, relax.Var):
+            if isinstance(x.struct_info, relax.TupleStructInfo):
+                return self.block_builder.emit(relax.TupleGetItem(x, 
node.args[1]))
+
+            assert isinstance(x.struct_info, relax.TensorStructInfo)
+            begin = []
+            end = []
+            stride = []
+            axes = []
+            expand_dim = []
+            i = 0
+            shape = self.shape_of(x)
+            for index in node.args[1]:
+                if isinstance(index, int):
+                    begin.append(index)
+                    end.append(index + 1)
+                    stride.append(1)
+                    axes.append(i)
+                    i = i + 1
+                elif isinstance(index, slice):
+                    begin.append(0 if index.start is None else index.start)
+                    end.append(shape[i] if index.stop is None else index.stop)
+                    stride.append(1 if index.step is None else index.step)
+                    axes.append(i)
+                    i = i + 1
+                elif index is None:
+                    expand_dim.append(i)
+                    i = i + 1
+                else:
+                    raise ValueError("Unsupported index type: " + 
str(type(index)))
+            while i < len(shape):
+                begin.append(0)
+                end.append(shape[i])
+                axes.append(i)
+                i = i + 1
+            sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, 
begin, end, stride))
+            sliced_shape = list(self.shape_of(sliced))
+            for i in expand_dim:
+                sliced_shape.insert(i, 1)
+            return self.block_builder.emit(relax.op.reshape(sliced, 
sliced_shape))
+        else:
+            assert False
+
+    def create_convert_map(self):
+        from torch import nn
+        from torch import fx
+
+        self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], 
relax.Var]] = {
+            # call_module
+            nn.Linear: self._linear,
+            nn.Conv2d: self._conv2d,
+            nn.MaxPool2d: self._max_pool2d,
+            nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
+            nn.Softmax: self._softmax,
+            nn.ReLU: lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
+            nn.ReLU6: lambda node: self.block_builder.emit(
+                relax.op.clip(self.env[node.args[0]], 0, 6)
+            ),
+            nn.SiLU: lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            nn.Flatten: self._flatten,
+            nn.BatchNorm2d: self._batch_norm_2d,
+            nn.LayerNorm: self._layer_norm,
+            nn.GroupNorm: self._group_norm,
+            nn.Dropout: lambda node: self.env[node.args[0]],
+            nn.modules.sparse.Embedding: self._embedding,
+            # call_function and call_method
+            "cos": self._cos,
+            "sin": self._sin,
+            "add": self._add,
+            "floordiv": self._floordiv,
+            "mul": self._mul,
+            "sub": self._sub,
+            "sqrt": self._sqrt,
+            "lt": self._lt,
+            "truediv": self._truediv,
+            "new_ones": self._new_ones,
+            "tril": self._tril,
+            "sum": self._sum,
+            "float": self._float,
+            "half": self._half,
+            "type": self._type,
+            "matmul": self._matmul,
+            "addmm": self._addmm,
+            "cat": self._cat,
+            "expand": self._expand,
+            "flatten": self._flatten,
+            "permute": self._permute,
+            "reshape": self._reshape,
+            "split": self._split,
+            "transpose": self._transpose,
+            "unsqueeze": lambda node: self.block_builder.emit(
+                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
+            ),
+            "view": self._reshape,
+            "softmax": self._softmax,
+            "clamp": self._clamp,
+            "relu": lambda node: 
self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
+            "gelu": lambda node: 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
+            "interpolate": self._interpolate,
+            "size": self._size,
+            "getattr": self._getattr,
+            "getitem": self._getitem,
+            "contiguous": lambda node: self.env[node.args[0]],
+            "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
+        }
+
+    def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> 
tvm.IRModule:
+        """Convert a PyTorch FX GraphModule to a Relax program."""
+        from torch import fx
+
+        self.named_modules = dict(model.named_modules())
+
+        graph: fx.Graph = model.graph
+        # Create input variables.
+        inputs = list()
+        for idx, (shape, dtype) in enumerate(input_info):
+            inputs.append(
+                relax.Var(
+                    f"inp_{idx}", relax.TensorStructInfo(shape, 
self._convert_data_type(dtype))
+                )
+            )
+
+        # Initialize the block builder with a function and a dataflow block.
+        self.block_builder = relax.BlockBuilder()
+        with self.block_builder.function(name="main", params=inputs.copy()):
+            output = None
+            with self.block_builder.dataflow():
+                # Translate model parameters.
+                for _, param in model.named_parameters():
+                    shape = param.data.shape
+                    dtype = self._convert_data_type(str(param.data.dtype))
+                    if dtype in ("float32", "float16"):
+                        self.params[param] = relax.const(
+                            param.data.cpu().numpy(), 
relax.TensorStructInfo(shape, dtype)
+                        )
+                    else:
+                        raise ValueError("Unsupported data type for model 
parameters: %s" % dtype)
+                # Translate the model.
+                for node in graph.nodes:
+                    if node.op == "placeholder":
+                        assert len(inputs) > 0, "Provided inputs is less than 
actual inputs"
+                        self.env[node] = inputs.pop(0)
+                    elif node.op == "output":
+                        args = self.retrieve_args(node)
+                        output = self.block_builder.emit_output(args[0])
+                        break
+                    elif node.op == "get_attr":
+                        self.env[node] = TorchFXImporter._fetch_attr(model, 
node.target)
+                    elif node.op == "call_module":
+                        module = self.named_modules[node.target]
+                        assert (
+                            type(module) in self.convert_map
+                        ), f"Unsupported module type {type(module)}"
+                        self.env[node] = self.convert_map[type(module)](node)
+                    elif node.op == "call_function":
+                        func_name = node.name.rstrip("0123456789_")
+                        assert (
+                            func_name in self.convert_map
+                        ), f"Unsupported function type {func_name}"
+                        self.env[node] = self.convert_map[func_name](node)
+                    elif node.op == "call_method":
+                        assert (
+                            node.target in self.convert_map
+                        ), f"Unsupported function target {node.target}"
+                        self.env[node] = self.convert_map[node.target](node)
+                    else:
+                        raise ValueError(f"Unsupported op {node.op}")
+            assert output is not None
+            self.block_builder.emit_func_output(output)
+
+        return self.block_builder.get()
+
+
+def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
+    """Convert a PyTorch FX GraphModule to a Relax program
+
+    Parameters
+    ----------
+    model : fx.GraphModule
+        The PyTorch FX GraphModule to convert.
+
+    input_info : List[Tuple[Tuple[int], str]]
+        A list of shapes and data types of input tensors.
+
+    Returns
+    -------
+    module : tvm.IRModule
+        The converted Relax program.
+
+    Examples
+    --------
+    Users can use the FX tracer or dynamo.export() to extract
+    a fx.GraphModule from a PyTorch model. The following codes show
+    how to convert a PyTorch model to a Relax program.
+
+    .. code-block:: python
+
+        # Import the importer.
+        import numpy as np
+        import torch
+        from tvm.relax.frontend.torch_fx import from_fx
+        from torch import _dynamo as dynamo
+
+        # Define the module
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features=10, out_features=7, 
bias=True)
+
+            def forward(self, input):
+                return self.linear(input)
+
+        # Instantiate the model and create the input info dict.
+        torch_model = MyModule()
+        input_info = [((128, 10), "float32")]
+        input_tensors = [
+            torch.astensor(np.random.randn(*shape).astype(dtype))
+            for shape, dtype in input_info
+        ]
+
+        # Use FX tracer to trace the PyTorch model.
+        graph_module = fx.symbolic_trace(torch_model)
+
+        # Use the dynamo.export() to export the PyTorch model to FX.
+        try:
+            graph_module = dynamo.export(torch_model, *input_tensors)
+        except:
+            raise RuntimeError("Failed to export the PyTorch model to FX.")
+
+        # Use the importer to import the PyTorch model to Relax.
+        mod: tvm.IRModule = from_fx(graph_module, input_info)
+
+        # Print out the imported model.
+        print(mod.script())
+
+    Notes
+    -----
+    For a given PyTorch model, to lookup the names of the model inputs in
+    FX, one can use
+
+    .. code-block:: python
+
+        fx.symbolic_trace(model).graph.print_tabular()
+
+    to print out the tabular representation of the PyTorch module, and then
+    check the placeholder rows in the beginning of the tabular.
+    """
+    return TorchFXImporter().from_fx(model, input_info)
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
new file mode 100644
index 0000000000..a5da15b76d
--- /dev/null
+++ b/python/tvm/relax/pipeline.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pre-defined pipelines.
+
+oRelax enables flexible pipeline optimizations before min build.
+This namespace offers a pre-defined collection that can be used
+as it is or serves as a basis to do further composition.
+"""
+# pylint: disable=unused-argument
+import tvm
+from tvm import meta_schedule as ms
+from . import transform
+
+
[email protected]_pass(opt_level=0)
+def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> 
tvm.ir.IRModule:
+    """Pipeline that applies pre-tuned logs.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        Input IRModule.
+
+    ctx : tvm.transform.PassContext
+        The pass context
+
+    Returns
+    -------
+    mod: tvm.ir.IRModule
+        The result transformed module.
+    """
+    seq = tvm.transform.Sequential(
+        [
+            transform.LegalizeOps(),
+            transform.AnnotateTIROpPattern(),
+            transform.FoldConstant(),
+            transform.FuseOps(),
+            transform.FuseTIR(),
+        ]
+    )
+    mod = seq(mod)
+    if ms.Database.current():
+        mod = transform.MetaScheduleApplyDatabase()(mod)
+    return mod
+
+
+# global map of pre-built pipelines
+PIPELINE_MAP = {"zero": zero_pipeline}
+
+
+def get_pipeline(name: str = "zero") -> tvm.transform.Pass:
+    """Get pre-build pipeline by name
+
+    Parameters
+    ----------
+    name : Optional[str]
+        Name of the pipeline
+
+    Returns
+    -------
+    pipeline: tvm.transform.Pass
+       The transformation pipeline.
+    """
+
+    if name in PIPELINE_MAP:
+        return PIPELINE_MAP[name]
+    else:
+        raise ValueError(
+            f"Unknown pre-built pipeline {name}," f"candidates are 
{list(PIPELINE_MAP.keys())}"
+        )
diff --git a/tests/python/relax/test_frontend_dynamo.py 
b/tests/python/relax/test_frontend_dynamo.py
new file mode 100644
index 0000000000..370df2103d
--- /dev/null
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("torch._dynamo")
+
+
+import tvm
+from tvm import relax, meta_schedule as ms, tir
+import tvm.testing
+import torch
+import torch._dynamo as dynamo
+from tvm.relax.frontend.torch import relax_dynamo
+from tvm.script.parser import relax as R, tir as T
+
+
+def test_relax_dynamo():
+    class Input1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(100, 10)
+
+        def forward(self, x):
+            return torch.nn.functional.relu(self.lin(x))
+
+    model = Input1()
+    ### construct the database
+    @tvm.script.ir_module
+    class Input1_ir:
+        @T.prim_func
+        def main(
+            inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"],
+            param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"],
+            param_1: T.Buffer[T.int64(10), "float32"],
+            compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"],
+        ):
+            # function attr dict
+            T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+            # body
+            # with T.block("root")
+            matmul = T.alloc_buffer([T.int64(10), T.int64(10)], 
dtype="float32")
+            T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32")
+            for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)):
+                with T.block("matmul"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1])
+                    T.writes(matmul[v_i0, v_i1])
+                    with T.init():
+                        matmul[v_i0, v_i1] = T.float32(0)
+                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] 
* param_0[v_k, v_i1]
+            for ax0, ax1 in T.grid(T.int64(10), T.int64(10)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1])
+                    T.writes(T_add[v_ax0, v_ax1])
+                    T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1]
+            for i0, i1 in T.grid(T.int64(10), T.int64(10)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(T_add[v_i0, v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], 
T.float32(0))
+
+    db = ms.Database.create("memory")
+    workload = db.commit_workload(Input1_ir)
+
+    sch = tir.Schedule(Input1_ir, debug_mask="all")
+    b0 = sch.get_block(name="matmul", func_name="main")
+    b1 = sch.get_block(name="T_add", func_name="main")
+    b2 = sch.get_block(name="root", func_name="main")
+    sch.compute_inline(block=b1)
+    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", 
ann_val="SSRSRS")
+    l3, l4, l5 = sch.get_loops(block=b0)
+    v6, v7, v8, v9 = sch.sample_perfect_tile(
+        loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1]
+    )
+    l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], 
preserve_unit_iters=True)
+    v14, v15, v16, v17 = sch.sample_perfect_tile(
+        loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1]
+    )
+    l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], 
preserve_unit_iters=True)
+    v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, 
decision=[100, 1])
+    l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True)
+    sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)
+    (b26,) = sch.get_consumers(block=b0)
+    sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, 
index=-1)
+    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", 
ann_val=96)
+    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", 
ann_val=64)
+    v27 = sch.sample_categorical(
+        candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0
+    )
+    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", 
ann_val=v27)
+
+    tuning_record = ms.database.TuningRecord(sch.trace, workload, 
run_secs=[0.0])
+    db.commit_tuning_record(tuning_record)
+    ### Optimize the model with tuned-log
+    with db:
+        opt_model = torch.compile(model, backend=relax_dynamo())
+    inp = torch.randn(10, 100)
+    tvm.testing.assert_allclose(
+        opt_model(inp).detach().numpy(), model(inp).detach().numpy(), 
rtol=1e-5, atol=1e-5
+    )
+
+
+def test_subgraph_capture():
+    import torch
+    from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs
+
+    class Input1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.lin = torch.nn.Linear(100, 10)
+
+        def forward(self, x):
+            return torch.nn.functional.relu(self.lin(x))
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def subgraph_0(
+            inp_0: R.Tensor((10, 100), dtype="float32"),
+            w0: R.Tensor((10, 100), dtype="float32"),
+            w1: R.Tensor((10,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, 
axes=None)
+                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, 
out_dtype="float32")
+                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
+                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    model = Input1()
+    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
+    binding = {"w0": model.lin.weight.detach().numpy(), "w1": 
model.lin.bias.detach().numpy()}
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+    def Input2(a, b):
+        x = a / (torch.sin(a) + 1)
+        if torch.sum(b) < 1:
+            b = b * -1
+        return x * b
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def subgraph_0(
+            inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), 
dtype="bool")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0)
+                lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, 
"float32"))
+                lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1)
+                lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, 
keepdims=False)
+                lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, 
"float32"))
+                gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), 
dtype="bool")) = (
+                    lv2,
+                    lv4,
+                )
+                R.output(gv)
+            return gv
+
+        @R.function
+        def subgraph_1(
+            inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), 
dtype="float32")
+        ) -> R.Tuple(R.Tensor((10,), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, 
inp_01)
+                gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,)
+                R.output(gv1)
+            return gv1
+
+    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
+    tvm.ir.assert_structural_equal(mod, Expected2)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
new file mode 100644
index 0000000000..9b35d34bd3
--- /dev/null
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -0,0 +1,1729 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script.parser import relax as R, tir as T
+
+
+def verify_model(torch_model, input_info, binding, expected):
+    from torch import fx
+    from tvm.relax.frontend.torch import from_fx
+
+    graph_model = fx.symbolic_trace(torch_model)
+    mod = from_fx(graph_model, input_info)
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    expected = relax.transform.BindParams("main", binding)(expected)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
[email protected]_gpu
+def test_conv():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    class Conv2D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+        ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    model = Conv2D1()
+    binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
+    verify_model(model, input_info, binding, expected1)
+
+    model = Conv2D2()
+    binding = {"w1": model.conv.weight.numpy()}
+    verify_model(model, input_info, binding, expected2)
+
+
[email protected]_gpu
+def test_linear():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    # nn.Linear
+    class Dense1(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=True)
+
+        def forward(self, input):
+            return self.linear(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((7, 10), dtype="float32"),
+            w2: R.Tensor((1, 7), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, 
axes=None)
+                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+                    input_1, lv, out_dtype="float32"
+                )
+                lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
+                gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    class Dense2(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=False)
+
+        def forward(self, input):
+            return self.linear(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((7, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, 
axes=None)
+                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+                    input_1, lv, out_dtype="float32"
+                )
+                gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    model = Dense1()
+    binding = {"w1": model.linear.weight.numpy(), "w2": 
model.linear.bias.numpy()}
+    verify_model(model, input_info, binding, expected1)
+
+    model = Dense2()
+    binding = {"w1": model.linear.weight.numpy()}
+    verify_model(model, input_info, binding, expected2)
+
+    # matmul
+    class MatMul1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.matmul(x, y)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32"),
+            input_2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(
+                    input_1, input_2, out_dtype="float32"
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(
+        MatMul1(),
+        [([10, 10], "float32"), ([10, 10], "float32")],
+        {},
+        expected3,
+    )
+
+
[email protected]_gpu
+def test_relu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class ReLU0(Module):
+        def __init__(self):
+            super().__init__()
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, input):
+            return self.relu(input)
+
+    class ReLU1(Module):
+        def forward(self, input):
+            return torch.nn.functional.relu(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([10, 10], "float32")]
+    verify_model(ReLU0(), input_info, {}, expected)
+    verify_model(ReLU1(), input_info, {}, expected)
+
+
[email protected]_gpu
+def test_relu6():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class ReLU6(Module):
+        def __init__(self):
+            super().__init__()
+            self.relu6 = torch.nn.ReLU6()
+
+        def forward(self, input):
+            return self.relu6(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 
10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([10, 10], "float32")]
+    verify_model(ReLU6(), input_info, {}, expected)
+
+
[email protected]_gpu
+def test_maxpool2d():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class MaxPool2d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[1, 1],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool2d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 4, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[2, 2],
+                    strides=[2, 2],
+                    dilation=[2, 3],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class MaxPool2d3(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, 
stride=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 6, 6), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[4, 4],
+                    strides=[2, 2],
+                    dilation=[1, 1],
+                    padding=[2, 2, 2, 2],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(MaxPool2d(), input_info, {}, expected1)
+    verify_model(MaxPool2d2(), input_info, {}, expected2)
+    verify_model(MaxPool2d3(), input_info, {}, expected3)
+
+
[email protected]_gpu
+def test_adaptive_avgpool2d():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class AdaptiveAvgPool2d0(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool2d1(Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.adaptive_avg_pool2d(
+                    input_1, output_size=[10, 10], layout="NCHW", 
out_layout="NCHW"
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(AdaptiveAvgPool2d0(), input_info, {}, expected1)
+    verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_flatten():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Flatten(Module):
+        def __init__(self):
+            super().__init__()
+            self.f = torch.nn.Flatten(2, -1)
+
+        def forward(self, input):
+            return self.f(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 100), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 100), dtype="float32") = 
R.reshape(input_1, (1, 3, 100))
+                gv: R.Tensor((1, 3, 100), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # call_module
+    verify_model(Flatten(), input_info, {}, expected1)
+    # call_method
+    verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_batchnorm2d():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class BatchNorm2d(Module):
+        def __init__(self):
+            super().__init__()
+            self.bn = torch.nn.BatchNorm2d(3)
+
+        def forward(self, input):
+            return self.bn(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+            w3: R.Tensor((3,), dtype="float32"),
+            w4: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                    R.Tensor((3,), dtype="float32"),
+                ) = R.nn.batch_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    w3,
+                    w4,
+                    axis=1,
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    model = BatchNorm2d()
+    binding = {
+        "w1": model.bn.weight.numpy(),
+        "w2": model.bn.bias.numpy(),
+        "w3": model.bn.running_mean.numpy(),
+        "w4": model.bn.running_var.numpy(),
+    }
+    verify_model(BatchNorm2d(), input_info, binding, expected1)
+
+
[email protected]_gpu
+def test_embedding():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([4], "int64")]
+
+    class Embedding(Module):
+        def __init__(self):
+            super().__init__()
+            self.embedding = torch.nn.Embedding(10, 3)
+
+        def forward(self, input):
+            return self.embedding(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), 
dtype="float32")
+        ) -> R.Tensor((4, 3), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, 
dtype="int32")
+                lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
+                gv: R.Tensor((4, 3), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    model = Embedding()
+    binding = {"w1": model.embedding.weight.numpy()}
+    verify_model(model, input_info, binding, expected1)
+
+
[email protected]_gpu
+def test_dropout():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Dropout(Module):
+        def __init__(self):
+            super().__init__()
+            self.dropout = torch.nn.Dropout(0.5)
+
+        def forward(self, input):
+            return self.dropout(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1
+                R.output(gv)
+            return gv
+
+    verify_model(Dropout(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_layernorm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class LayerNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.ln = torch.nn.LayerNorm((10, 10))
+
+        def forward(self, input):
+            return self.ln(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((10, 10), dtype="float32"),
+            w2: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.layer_norm(
+                    input_1,
+                    w1,
+                    w2,
+                    axes=[-2, -1],
+                    epsilon=1e-05,
+                    center=True,
+                    scale=True,
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    model = LayerNorm()
+    binding = {
+        "w1": model.ln.weight.numpy(),
+        "w2": model.ln.bias.numpy(),
+    }
+    verify_model(LayerNorm(), input_info, binding, expected1)
+
+
[email protected]_gpu
+def test_silu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class SiLU(Module):
+        def __init__(self):
+            super().__init__()
+            self.silu = torch.nn.SiLU()
+
+        def forward(self, input):
+            return self.silu(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.silu(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(SiLU(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_groupnorm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class GroupNorm(Module):
+        def __init__(self):
+            super().__init__()
+            self.gn = torch.nn.GroupNorm(3, 3)
+
+        def forward(self, input):
+            return self.gn(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((3,), dtype="float32"),
+            w2: R.Tensor((3,), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape(
+                    input_1, (1, 3, 1, 10, 10)
+                )
+                lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean(
+                    lv, axis=[2, 3, 4], keepdims=True
+                )
+                lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = 
R.subtract(lv, lv1)
+                lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = 
R.multiply(lv2, lv2)
+                lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum(
+                    lv3, axis=[2, 3, 4], keepdims=True
+                )
+                lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = 
R.divide(lv4, R.const(100.0))
+                lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, 
R.const(1e-05))
+                lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6)
+                lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = 
R.divide(lv2, lv7)
+                lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = 
R.reshape(w1, (1, 3, 1, 1, 1))
+                lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = 
R.reshape(w2, (1, 3, 1, 1, 1))
+                lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = 
R.multiply(lv8, lv9)
+                lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = 
R.add(lv11, lv10)
+                lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.reshape(lv12, (1, 3, 10, 10))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13
+                R.output(gv)
+            return gv
+
+    model = GroupNorm()
+    binding = {
+        "w1": model.gn.weight.numpy(),
+        "w2": model.gn.bias.numpy(),
+    }
+    verify_model(model, input_info, binding, expected1)
+
+
[email protected]_gpu
+def test_softmax():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Softmax(Module):
+        def __init__(self):
+            super().__init__()
+            self.sm = torch.nn.Softmax(dim=1)
+
+        def forward(self, input):
+            return self.sm(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.softmax(input_1, axis=1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Softmax(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_binary():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
+    input_info2 = [([1, 3, 10, 10], "float32")]
+    # Add
+    class Add1(Module):
+        def forward(self, lhs, rhs):
+            return lhs + rhs
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Add2(Module):
+        def forward(self, lhs):
+            return lhs + 1.0
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, 
R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Add1(), input_info1, {}, expected1)
+    verify_model(Add2(), input_info2, {}, expected2)
+
+    # Sub
+    class Sub1(Module):
+        def forward(self, lhs, rhs):
+            return lhs - rhs
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.subtract(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Sub2(Module):
+        def forward(self, lhs):
+            return lhs - 1.0
+
+    @tvm.script.ir_module
+    class expected4:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.subtract(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sub1(), input_info1, {}, expected3)
+    verify_model(Sub2(), input_info2, {}, expected4)
+
+    # Mul
+    class Mul1(Module):
+        def forward(self, lhs, rhs):
+            return lhs * rhs
+
+    @tvm.script.ir_module
+    class expected5:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Mul2(Module):
+        def forward(self, lhs):
+            return lhs * 1.0
+
+    @tvm.script.ir_module
+    class expected6:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Mul1(), input_info1, {}, expected5)
+    verify_model(Mul2(), input_info2, {}, expected6)
+
+    # True div
+    class TrueDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs / rhs
+
+    @tvm.script.ir_module
+    class expected7:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.divide(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class TrueDiv2(Module):
+        def forward(self, lhs):
+            return lhs / 1.0
+
+    @tvm.script.ir_module
+    class expected8:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.divide(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(TrueDiv1(), input_info1, {}, expected7)
+    verify_model(TrueDiv2(), input_info2, {}, expected8)
+
+    # Floor div
+    class FloorDiv1(Module):
+        def forward(self, lhs, rhs):
+            return lhs // rhs
+
+    @tvm.script.ir_module
+    class expected9:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.floor_divide(lhs_1, rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class FloorDiv2(Module):
+        def forward(self, lhs):
+            return lhs // 1.0
+
+    @tvm.script.ir_module
+    class expected10:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.floor_divide(lhs_1, R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(FloorDiv1(), input_info1, {}, expected9)
+    verify_model(FloorDiv2(), input_info2, {}, expected10)
+
+    # LT
+    class LT1(Module):
+        def forward(self, lhs, rhs):
+            return lhs < rhs
+
+    @tvm.script.ir_module
+    class expected11:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, 
rhs_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    class LT2(Module):
+        def forward(self, lhs):
+            return lhs < 1.0
+
+    @tvm.script.ir_module
+    class expected12:
+        @R.function
+        def main(
+            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, 
R.const(1.0))
+                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(LT1(), input_info1, {}, expected11)
+    verify_model(LT2(), input_info2, {}, expected12)
+
+
[email protected]_gpu
+def test_size():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Size(Module):
+        def forward(self, input):
+            return input.size()
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> 
R.Shape([1, 3, 10, 10]):
+            # block 0
+            with R.dataflow():
+                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
+                R.output(gv)
+            return gv
+
+    verify_model(Size(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_unsqueeze():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Unsqueeze1(Module):
+        def forward(self, input):
+            return input.unsqueeze(1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = 
R.expand_dims(input_1, 1)
+                gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    class Unsqueeze2(Module):
+        def forward(self, input):
+            return input.unsqueeze(-1)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = 
R.expand_dims(input_1, -1)
+                gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Unsqueeze1(), input_info, {}, expected1)
+    verify_model(Unsqueeze2(), input_info, {}, expected2)
+
+
[email protected]_gpu
+def test_getattr():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class GetAttr1(Module):
+        def forward(self, input):
+            return input.shape
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> 
R.Shape([1, 3, 10, 10]):
+            # block 0
+            with R.dataflow():
+                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
+                R.output(gv)
+            return gv
+
+    verify_model(GetAttr1(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_getitem():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Slice1(Module):
+        def forward(self, x):
+            return x[0, 1::2, :, :3]
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 1, 10, 3), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice(
+                    x,
+                    axes=[0, 1, 2, 3],
+                    begin=[0, 1, 0, 0],
+                    end=[1, T.int64(3), T.int64(10), 3],
+                    strides=[1, 2, 1, 1],
+                )
+                lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, 
(1, 1, 10, 3))
+                gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Slice1(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_unary():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    # sin
+    class Sin(Module):
+        def forward(self, input):
+            return torch.sin(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sin(), input_info, {}, expected1)
+
+    # cos
+    class Cos(Module):
+        def forward(self, input):
+            return torch.cos(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cos(), input_info, {}, expected2)
+
+    # sqrt
+    class Sqrt(Module):
+        def forward(self, input):
+            return torch.sqrt(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sqrt(), input_info, {}, expected3)
+
+
[email protected]_gpu
+def test_gelu():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Gelu(Module):
+        def forward(self, input):
+            return torch.nn.functional.gelu(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.gelu(input_1)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Gelu(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_clamp():
+    import torch
+    from torch import fx
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Clamp(Module):
+        def forward(self, input):
+            return torch.clamp(input, min=0.1, max=0.5)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.clip(input_1, 0.1, 0.5)
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Clamp(), input_info, {}, expected1)
+
+    from tvm.relax.frontend.torch import from_fx
+
+    with pytest.raises(
+        ValueError, match="TVM only supports constant max value for 
torch.clamp/clip"
+    ):
+
+        class Clamp_Error(Module):
+            def forward(self, input):
+                return torch.clamp(input, min=0.5, max=None)
+
+        gm = fx.symbolic_trace(Clamp_Error())
+        from_fx(gm, input_info)
+
+    with pytest.raises(
+        ValueError, match="TVM only supports constant min value for 
torch.clamp/clip"
+    ):
+
+        class Clamp_Error(Module):
+            def forward(self, input):
+                return torch.clamp(input, min=input, max=input)
+
+        gm = fx.symbolic_trace(Clamp_Error())
+        from_fx(gm, input_info)
+
+
[email protected]_gpu
+def test_interpolate():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Interpolate(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, size=(5, 5))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 5, 5), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d(
+                    input_1,
+                    (5, 5),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="asymmetric",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_addmm():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+        ([10, 10], "float32"),
+    ]
+
+    class Addmm(Module):
+        def forward(self, x1, x2, x3):
+            return torch.addmm(x1, x2, x3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x1: R.Tensor((10, 10), dtype="float32"),
+            x2: R.Tensor((10, 10), dtype="float32"),
+            x3: R.Tensor((10, 10), dtype="float32"),
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, 
out_dtype="float32")
+                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
+                gv: R.Tensor((10, 10), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(Addmm(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_split():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 3, 10, 10], "float32")]
+
+    class Split(Module):
+        def forward(self, input):
+            return torch.split(input, 1, dim=1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+            R.Tensor((1, 1, 10, 10), dtype="float32"),
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                gv: R.Tuple(
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                    R.Tensor((1, 1, 10, 10), dtype="float32"),
+                ) = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Split(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_tril():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([10, 10], "float32")]
+
+    class Tril(Module):
+        def forward(self, input):
+            return torch.tril(input, 1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((10, 10), dtype="float32")
+        ) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tril(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_new_ones():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3], "float32")]
+
+    class NewOnes(Module):
+        def forward(self, x):
+            return x.new_ones(1, 2, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 
3), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
+                    (1, 2, 3), R.const(1, "float32"), dtype="float32"
+                )
+                gv: R.Tensor((1, 2, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(NewOnes(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_expand():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class Expand(Module):
+        def forward(self, x):
+            return x.expand(4, 2, 3, 4)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((4, 2, 3, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 2, 3, 4), dtype="float32") = 
R.broadcast_to(x, (4, 2, 3, 4))
+                gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Expand(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_reduce():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    # sum
+    class Sum(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1))
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 
1], keepdims=False)
+                gv: R.Tensor((1, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Sum(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_to():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    # float
+    class ToFloat(Module):
+        def forward(self, x):
+            return x.float()
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 2, 3, 4), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, 
dtype="float32")
+                gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ToFloat(), input_info, {}, expected1)
+
+    # half
+    class ToHalf(Module):
+        def forward(self, x):
+            return x.half()
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 2, 3, 4), dtype="float16"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, 
dtype="float16")
+                gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(ToHalf(), input_info, {}, expected2)
+
+
[email protected]_gpu
+def test_permute():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class Permute(Module):
+        def forward(self, x):
+            return x.permute(0, 3, 2, 1)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 4, 3, 2), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = 
R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Permute(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_reshape():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class Reshape(Module):
+        def forward(self, x):
+            return x.reshape(2, 12)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 
12), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tensor((2, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Reshape(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_transpose():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class Transpose(Module):
+        def forward(self, x):
+            return x.transpose(1, 3)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tensor((1, 4, 3, 2), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = 
R.permute_dims(x, axes=[0, 3, 2, 1])
+                gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Transpose(), input_info, {}, expected1)
+
+
[email protected]_gpu
+def test_view():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+    torch.random.manual_seed(0)
+
+    input_info = [([1, 2, 3, 4], "float32")]
+
+    class View(Module):
+        def forward(self, x):
+            return x.view(2, 12)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 
12), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tensor((2, 12), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(View(), input_info, {}, expected1)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_pipeline.py 
b/tests/python/relax/test_pipeline.py
new file mode 100644
index 0000000000..6d6704ae97
--- /dev/null
+++ b/tests/python/relax/test_pipeline.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import relax
+from tvm.script import relax as R
+
+
+def test_pipeline_compile():
+    pipeline = relax.get_pipeline()
+
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), 
"float32")):
+            lv0 = R.add(x, y)
+            return lv0
+
+    mod = Mod
+    mod = pipeline(mod)
+    target = tvm.target.Target("llvm", host="llvm")
+
+    ex = relax.vm.build(mod, target)
+    x_np = np.random.rand(3, 4).astype(np.float32)
+    y_np = np.random.rand(3, 4).astype(np.float32)
+    x = tvm.nd.array(x_np)
+    y = tvm.nd.array(y_np)
+
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    z = vm["main"](x, y)
+    tvm.testing.assert_allclose(z.numpy(), x_np + y_np, rtol=1e-7, atol=1e-7)

Reply via email to