This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f2c91a652 [Relax][PyTorch] Add support for `torch.export.ExportedProgram` in Relax PyTorch Frontend (#17396)
3f2c91a652 is described below

commit 3f2c91a652a0a867703f2bc4176b80b2d1747c25
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Sep 27 10:00:17 2024 +0900

    [Relax][PyTorch] Add support for `torch.export.ExportedProgram` in Relax PyTorch Frontend (#17396)
    
    * introduce ExportedProgramImporter
    
    * address review comments
---
 python/tvm/relax/frontend/torch/__init__.py        |   1 +
 .../frontend/torch/base_fx_graph_translator.py     | 228 +++++++++
 .../frontend/torch/exported_program_translator.py  | 243 ++++++++++
 python/tvm/relax/frontend/torch/fx_translator.py   | 209 +-------
 .../relax/test_frontend_from_exported_program.py   | 535 +++++++++++++++++++++
 5 files changed, 1029 insertions(+), 187 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/__init__.py 
b/python/tvm/relax/frontend/torch/__init__.py
index 55da5a456d..36eac975df 100644
--- a/python/tvm/relax/frontend/torch/__init__.py
+++ b/python/tvm/relax/frontend/torch/__init__.py
@@ -17,5 +17,6 @@
 """
 PyTorch Frontends for constructing Relax programs, with the model importers
 """
+from .exported_program_translator import from_exported_program
 from .fx_translator import from_fx
 from .dynamo import relax_dynamo, dynamo_capture_subgraphs
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
new file mode 100644
index 0000000000..6a001b5a04
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements, 
unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""Base class for PyTorch FX Graph importer."""
+import abc
+from typing import Callable, Dict, Optional, Tuple, Union
+
+from tvm import relax
+
+
+class BaseFXGraphImporter(metaclass=abc.ABCMeta):
+    """Base class for FX Graph Importer."""
+
+    import torch  # type: ignore
+    from torch import fx
+
+    def __init__(self) -> None:
+        import torch  # type: ignore
+        from torch import fx
+
+        self.env: Dict[fx.Node, relax.Expr] = {}
+        self.params: Dict[torch.Tensor, relax.Expr] = {}
+        self.block_builder: relax.BlockBuilder = None
+        self.convert_map: Dict[
+            Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
+        ] = self.create_convert_map()
+
+    ########## Utilities ##########
+
+    @staticmethod
+    def _convert_data_type(input_type: Union[str, torch.dtype], env: 
Optional[Dict] = None):
+        """converts the PyTorch scalar type input_type to a TVM dtype."""
+        import torch  # type: ignore
+
+        if env is not None and input_type in env:
+            input_type = env[input_type]
+
+        input_type = input_type.lower() if isinstance(input_type, str) else 
input_type
+        if input_type in ["float", "float32", "torch.float32", torch.float32]:
+            return "float32"
+        elif input_type in ["float16", "torch.float16", torch.float16]:
+            return "float16"
+        elif input_type in ["int64", "torch.int64", torch.int64]:
+            return "int64"
+        elif input_type in ["int32", "torch.int32", torch.int32]:
+            return "int32"
+        elif input_type in ["bool", "torch.bool", torch.bool]:
+            return "bool"
+        else:
+            raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
+
+    @staticmethod
+    def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
+        tensor = tensor.detach().cpu()
+        dtype = BaseFXGraphImporter._convert_data_type(str(tensor.data.dtype))
+        return relax.const(tensor.data.numpy(), dtype)
+
+    @staticmethod
+    def shape_of(tensor):
+        """Get the shape of a tensor."""
+        import torch  # type: ignore
+
+        if isinstance(tensor, relax.Expr):
+            if not isinstance(tensor.struct_info, relax.TensorStructInfo):
+                raise TypeError("The input Expr of shape_of should be a 
Tensor")
+            return tensor.struct_info.shape
+        elif isinstance(tensor, torch.Tensor):
+            return tensor.shape
+        raise ValueError("Unsupported type: {}".format(type(tensor)))
+
+    def retrieve_args(self, node: fx.Node):
+        return self._retrieve_args(node.args)
+
+    def _retrieve_args(self, node):
+        from torch import fx
+
+        if isinstance(node, fx.Node):
+            return self.env[node]
+        elif isinstance(node, tuple):
+            return tuple(self._retrieve_args(x) for x in node)
+        elif isinstance(node, list):
+            return [self._retrieve_args(x) for x in node]
+        elif isinstance(node, dict):
+            return {self._retrieve_args(k): self._retrieve_args(v) for k, v in 
node.items()}
+        else:
+            return node
+
+    ########## Unary Ops ##########
+
+    def _unary_op(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node) -> relax.Var:
+            return self.block_builder.emit(op(self.env[node.args[0]]))
+
+        return convert
+
+    ########## Neural Network ##########
+
+    def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        output_size = node.args[1]
+        return self.block_builder.emit(
+            relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
+        )
+
+    def _conv2d_impl(
+        self,
+        x: relax.Expr,
+        weight: relax.Expr,
+        bias: Optional[relax.Expr],
+        strides: Optional[Tuple],
+        padding: Optional[Tuple],
+        dilation: Optional[Tuple],
+        groups: Optional[Tuple],
+    ):
+        conv2d = self.block_builder.emit(
+            relax.op.nn.conv2d(
+                x,
+                weight,
+                strides=strides,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+                out_dtype="float32",
+            )
+        )
+
+        if bias is None:
+            return conv2d
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1))
+        return self.block_builder.emit(relax.op.add(conv2d, bias))
+
+    def _conv2d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        groups = args[6] if len(args) > 6 else 1
+        return self._conv2d_impl(
+            x,
+            weight,
+            bias=bias,
+            strides=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+        )
+
+    def _linear(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
+
+    def _max_pool2d_impl(
+        self,
+        x: relax.Expr,
+        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+        stride: Optional[Union[int, Tuple[int, int]]] = None,
+        padding: Optional[int] = 0,
+        dilation: Optional[int] = 1,
+        ceil_mode: Optional[bool] = False,
+    ) -> relax.Var:
+        stride = kernel_size if stride is None else stride
+        return self.block_builder.emit(
+            relax.op.nn.max_pool2d(
+                x,
+                pool_size=kernel_size,
+                strides=stride,
+                padding=padding,
+                dilation=dilation,
+                ceil_mode=ceil_mode,
+                layout="NCHW",
+            )
+        )
+
+    def _max_pool2d(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+
+        return self._max_pool2d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
+
+    ########## Manipulation ##########
+
+    def _reshape(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
+        return self.block_builder.emit(relax.op.reshape(x, dims))
+
+    ########## Others ##########
+
+    @abc.abstractmethod
+    def create_convert_map(
+        self,
+    ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
+        """Create convert map"""
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
new file mode 100644
index 0000000000..9af422d1c3
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -0,0 +1,243 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements, 
unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""PyTorch ExportedProgram of Relax."""
+from collections import ChainMap, OrderedDict
+from typing import Callable, Dict, List, Tuple
+
+import torch
+import tvm
+from tvm import relax
+
+from .base_fx_graph_translator import BaseFXGraphImporter
+
+
+class ExportedProgramImporter(BaseFXGraphImporter):
+    """An importer from ExportedProgram to Relax."""
+
+    from torch import fx
+
+    def create_input_vars(
+        self, exported_program: torch.export.ExportedProgram
+    ) -> Tuple[List[relax.Var], List[relax.Var]]:
+        """Create relax input vars."""
+        parameters_buffers_constants = []
+        user_inputs = []
+        for spec in exported_program.graph_signature.input_specs:
+            name_hint = spec.arg.name
+            if spec.kind is 
torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
+                shape = exported_program.tensor_constants[spec.target].shape
+                torch_dtype = 
exported_program.tensor_constants[spec.target].dtype
+            elif spec.kind is 
torch.export.graph_signature.InputKind.USER_INPUT:
+                for node in 
exported_program.graph.find_nodes(op="placeholder", target=spec.target):
+                    if node.name == name_hint:
+                        shape = node.meta["tensor_meta"].shape
+                        torch_dtype = node.meta["tensor_meta"].dtype
+                        break
+            else:
+                # PARAMETER or BUFFER
+                shape = exported_program.state_dict[spec.target].shape
+                torch_dtype = exported_program.state_dict[spec.target].dtype
+
+            dtype = self._convert_data_type(torch_dtype)
+            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, 
dtype))
+            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+                user_inputs.append(relax_var)
+            else:
+                parameters_buffers_constants.append(relax_var)
+
+        return parameters_buffers_constants, user_inputs
+
+    def create_convert_map(
+        self,
+    ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
+        return {
+            # unary
+            "dropout.default": lambda node: self.env[node.args[0]],
+            "relu.default": self._unary_op(relax.op.nn.relu),
+            # neural network
+            "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+            "conv2d.default": self._conv2d,
+            "linear.default": self._linear,
+            "max_pool2d.default": self._max_pool2d,
+            # tensor manipulation
+            "view.default": self._reshape,
+        }
+
+    def from_exported_program(
+        self,
+        exported_program: torch.export.ExportedProgram,
+        keep_params_as_input: bool,
+        unwrap_unit_return_tuple: bool,
+        no_bind_return_tuple: bool,
+    ) -> tvm.IRModule:
+        """Convert a PyTorch ExportedProgram to a Relax program."""
+        from torch import fx  # type: ignore
+
+        # Create input variables.
+        parameter_buffer_constant_vars, user_input_vars = 
self.create_input_vars(exported_program)
+        inputs_vars = parameter_buffer_constant_vars + user_input_vars
+
+        # Initialize the block builder with a function and a dataflow block.
+        self.block_builder = relax.BlockBuilder()
+        func_name = "main"
+        func_attrs = {"num_input": len(user_input_vars)} if 
keep_params_as_input else None
+
+        nodes: List[fx.Node] = exported_program.graph.nodes
+        with self.block_builder.function(
+            name=func_name, params=inputs_vars.copy(), attrs=func_attrs
+        ):
+            output = None
+            with self.block_builder.dataflow():
+                # Translate the model.
+                for node in nodes:
+                    if node.op == "placeholder":
+                        if "grapharg" in node.meta and 
node.meta["grapharg"].fake_tensor is None:
+                            # Ignore sym input
+                            continue
+
+                        self.env[node] = inputs_vars.pop(0)
+                    elif node.op == "output":
+                        args = self.retrieve_args(node)
+                        assert len(args) == 1
+                        assert isinstance(args[0], (tuple, relax.Tuple))
+
+                        if unwrap_unit_return_tuple and len(args[0]) == 1:
+                            output = self.block_builder.emit_output(args[0][0])
+                        elif no_bind_return_tuple:
+                            output = []
+                            for ret in args[0]:
+                                
output.append(self.block_builder.emit_output(ret))
+                        else:
+                            output = self.block_builder.emit_output(args[0])
+                        break
+                    elif node.op == "get_attr":
+                        self.env[node] = 
getattr(exported_program.graph_module, node.target)
+                    elif node.op == "call_function":
+                        func_name = node.target.__name__
+                        assert (
+                            func_name in self.convert_map
+                        ), f"Unsupported function type {func_name}"
+                        self.env[node] = self.convert_map[func_name](node)
+                    else:
+                        raise ValueError(f"Unsupported op {node.op}")
+            assert output is not None
+            self.block_builder.emit_func_output(output)
+
+        to_bind_parameters = ChainMap(
+            OrderedDict(exported_program.named_buffers()), 
exported_program.constants
+        )
+        if not keep_params_as_input:
+            to_bind_parameters = to_bind_parameters.new_child(
+                OrderedDict(exported_program.named_parameters())
+            )
+
+        binding = {}
+        for tensor_name, tensor_value in to_bind_parameters.items():
+            # find relax var name from graph signature
+            for spec in exported_program.graph_signature.input_specs:
+                if tensor_name == spec.target:
+                    bind_name = spec.arg.name
+                    break
+            binding[bind_name] = tvm.nd.from_dlpack(tensor_value.detach())
+
+        mod = self.block_builder.get()
+        mod = relax.transform.BindParams("main", binding)(mod)
+
+        if keep_params_as_input:
+            parameters = dict(exported_program.named_parameters())
+            params = [tvm.nd.from_dlpack(p.detach()) for p in 
parameters.values()]
+            mod["main"] = mod["main"].with_attr("params", params)
+
+        return mod
+
+
+def from_exported_program(
+    exported_program: torch.export.ExportedProgram,
+    *,
+    keep_params_as_input: bool = False,
+    unwrap_unit_return_tuple: bool = False,
+    no_bind_return_tuple: bool = False,
+) -> tvm.IRModule:
+    """Convert a PyTorch ExportedProgram to a Relax program
+
+    Parameters
+    ----------
+    exported_program : torch.export.ExportedProgram
+        The PyTorch ExportedProgram to convert.
+
+    keep_params_as_input : bool
+        Whether to keep model parameters as input variables.
+
+    unwrap_unit_return_tuple : bool
+        A boolean flag indicating if to the return value when it is an unit 
tuple.
+        When the return value is not a unit tuple, no unwrap will take place.
+
+    no_bind_return_tuple : bool
+        A boolean flag indicating whether to bind the return tuple as a relax 
var.
+        If the flag is true and the return value is a tuple, it will not bind 
it to a var.
+
+    Returns
+    -------
+    output : tvm.IRModule
+        The import result IRModule, with the function "main" containing the
+        translated logic.
+
+    Examples
+    --------
+    Users can use the torch.export.export() to extract a 
torch.export.ExportedProgram
+    from a PyTorch model. The following codes show how to convert a PyTorch 
model to
+    a Relax program.
+
+    .. code-block:: python
+
+        # Import the importer.
+        import tvm
+        from tvm.relax.frontend.torch import from_exported_program
+        import torch
+        from torch.export import export
+
+        # Define the module
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_features=10, out_features=7, 
bias=True)
+
+            def forward(self, input):
+                return self.linear(input)
+
+        # Instantiate the model and create the input info dict.
+        torch_model = MyModule()
+
+        # Use torch.export.export() to convert the PyTorch model into 
ExportedProgram.
+        example_args = (torch.rand(128, 10, dtype=torch.float32),)
+        exported_program = export(torch_model, args=example_args)
+
+        # Use the importer to import the ExportedProgram to Relax.
+        mod: tvm.IRModule = from_exported_program(exported_program)
+    """
+    # decompose into Core ATen operators
+    exported_program.run_decompositions()
+
+    return ExportedProgramImporter().from_exported_program(
+        exported_program,
+        keep_params_as_input,
+        unwrap_unit_return_tuple,
+        no_bind_return_tuple,
+    )
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 27da69dbb1..ec53cf23ed 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,8 +24,10 @@ from functools import partial, reduce
 import tvm
 from tvm import relax
 
+from .base_fx_graph_translator import BaseFXGraphImporter
 
-class TorchFXImporter:
+
+class TorchFXImporter(BaseFXGraphImporter):
     """An importer from PyTorch FX to Relax."""
 
     import torch  # type: ignore
@@ -33,15 +35,12 @@ class TorchFXImporter:
 
     def __init__(self) -> None:
         import torch  # type: ignore
-        from torch import fx
 
-        self.env: Dict[fx.Node, relax.Expr] = {}
-        self.params: Dict[torch.Tensor, relax.Expr] = {}
+        super().__init__()
         self.named_modules: Dict[str, torch.Module] = None
-        self.block_builder: relax.BlockBuilder = None
-        self.create_convert_map()
 
     ########## Utilities ##########
+
     def _fetch_attr(self, model, target: str):
         import torch  # type: ignore
 
@@ -58,77 +57,11 @@ class TorchFXImporter:
             # If so, return the parameter instead.
             if attr_itr in self.params:
                 return self.params[attr_itr]
-            return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
+            return self._convert_torch_tensor_to_relax(attr_itr)
         return attr_itr
 
-    @staticmethod
-    def _convert_data_type(input_type: Union[str, torch.dtype], env: 
Optional[Dict] = None):
-        """converts the PyTorch scalar type input_type to a TVM dtype."""
-        import torch  # type: ignore
-
-        if env is not None and input_type in env:
-            input_type = env[input_type]
-
-        input_type = input_type.lower() if isinstance(input_type, str) else 
input_type
-        if input_type in ["float", "float32", "torch.float32", torch.float32]:
-            return "float32"
-        elif input_type in ["float16", "torch.float16", torch.float16]:
-            return "float16"
-        elif input_type in ["int64", "torch.int64", torch.int64]:
-            return "int64"
-        elif input_type in ["int32", "torch.int32", torch.int32]:
-            return "int32"
-        elif input_type in ["bool", "torch.bool", torch.bool]:
-            return "bool"
-        else:
-            raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
-
-    @staticmethod
-    def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
-        tensor = tensor.detach().cpu()
-        dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
-        return relax.const(tensor.data.numpy(), dtype)
-
-    @staticmethod
-    def shape_of(tensor):
-        """Get the shape of a tensor."""
-        import torch  # type: ignore
-
-        if isinstance(tensor, relax.Expr):
-            if not isinstance(tensor.struct_info, relax.TensorStructInfo):
-                raise TypeError("The input Expr of shape_of should be a 
Tensor")
-            return tensor.struct_info.shape
-        elif isinstance(tensor, torch.Tensor):
-            return tensor.shape
-        raise ValueError("Unsupported type: {}".format(type(tensor)))
-
-    def retrieve_args(self, node):
-        return self._retrieve_args(node.args)
-
-    def _retrieve_args(self, node):
-        from torch import fx
-
-        if isinstance(node, fx.Node):
-            return self.env[node]
-        elif isinstance(node, tuple):
-            return tuple(self._retrieve_args(x) for x in node)
-        elif isinstance(node, list):
-            return [self._retrieve_args(x) for x in node]
-        elif isinstance(node, dict):
-            return {self._retrieve_args(k): self._retrieve_args(v) for k, v in 
node.items()}
-        else:
-            return node
-
     ########## Unary Ops ##########
 
-    def _unary_op(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node) -> relax.Var:
-            return self.block_builder.emit(op(self.env[node.args[0]]))
-
-        return convert
-
     def _clamp(self, node: fx.Node) -> relax.Expr:
         args = self.retrieve_args(node)
         a_min = args[1] if len(args) > 1 else node.kwargs["min"]
@@ -272,13 +205,6 @@ class TorchFXImporter:
 
     ########## Neural Network ##########
 
-    def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
-        x = self.env[node.args[0]]
-        output_size = node.args[1]
-        return self.block_builder.emit(
-            relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
-        )
-
     def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
 
         module = self.named_modules[node.target]
@@ -590,55 +516,6 @@ class TorchFXImporter:
             groups=module.groups,
         )
 
-    def _conv2d_impl(
-        self,
-        x: relax.Expr,
-        weight: relax.Expr,
-        bias: Optional[relax.Expr],
-        strides: Optional[Tuple],
-        padding: Optional[Tuple],
-        dilation: Optional[Tuple],
-        groups: Optional[Tuple],
-    ):
-        conv2d = self.block_builder.emit(
-            relax.op.nn.conv2d(
-                x,
-                weight,
-                strides=strides,
-                padding=padding,
-                dilation=dilation,
-                groups=groups,
-                data_layout="NCHW",
-                kernel_layout="OIHW",
-                out_dtype="float32",
-            )
-        )
-
-        if bias is None:
-            return conv2d
-        assert len(self.shape_of(bias)) == 1
-        bias = relax.op.reshape(bias, (1, -1, 1, 1))
-        return self.block_builder.emit(relax.op.add(conv2d, bias))
-
-    def _conv2d(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        stride = args[3] if len(args) > 3 else 1
-        padding = args[4] if len(args) > 4 else 0
-        dilation = args[5] if len(args) > 5 else 1
-        groups = args[6] if len(args) > 6 else 1
-        return self._conv2d_impl(
-            x,
-            weight,
-            bias=bias,
-            strides=stride,
-            padding=padding,
-            dilation=dilation,
-            groups=groups,
-        )
-
     def _conv2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -940,13 +817,6 @@ class TorchFXImporter:
         eps = module.eps
         return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
 
-    def _linear(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        weight = args[1]
-        bias = args[2] if len(args) > 2 else None
-        return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
-
     def _linear_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -954,39 +824,6 @@ class TorchFXImporter:
         bias = self.params.get(module.bias, None)
         return self.block_builder.emit(relax.op.linear(x, weight, bias, 
"float32"))
 
-    def _max_pool2d_impl(
-        self,
-        x: relax.Expr,
-        kernel_size: Union[int, Tuple[int, int]] = (1, 1),
-        stride: Optional[Union[int, Tuple[int, int]]] = None,
-        padding: Optional[int] = 0,
-        dilation: Optional[int] = 1,
-        ceil_mode: Optional[bool] = False,
-    ) -> relax.Var:
-        stride = kernel_size if stride is None else stride
-        return self.block_builder.emit(
-            relax.op.nn.max_pool2d(
-                x,
-                pool_size=kernel_size,
-                strides=stride,
-                padding=padding,
-                dilation=dilation,
-                ceil_mode=ceil_mode,
-                layout="NCHW",
-            )
-        )
-
-    def _max_pool2d(self, node: fx.Node) -> relax.Var:
-        args = self.retrieve_args(node)
-        x = args[0]
-        kernel_size = args[1]
-        stride = args[2] if len(args) > 2 else None
-        padding = args[3] if len(args) > 3 else 0
-        dilation = args[4] if len(args) > 4 else 1
-        ceil_mode = args[5] if len(args) > 5 else False
-
-        return self._max_pool2d_impl(x, kernel_size, stride, padding, 
dilation, ceil_mode)
-
     def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -1138,14 +975,6 @@ class TorchFXImporter:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
         return self.block_builder.emit(relax.op.tile(x, dims))
 
-    def _reshape(self, node: fx.Node) -> relax.Var:
-        import torch  # type: ignore
-
-        args = self.retrieve_args(node)
-        x = args[0]
-        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else 
args[1:]
-        return self.block_builder.emit(relax.op.reshape(x, dims))
-
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -1448,12 +1277,23 @@ class TorchFXImporter:
         idx = node.args[1]
         return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
 
-    def create_convert_map(self):
+    def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> 
List[relax.Var]:
+        inputs = list()
+        for idx, (shape, dtype) in enumerate(input_info):
+            inputs.append(
+                relax.Var(
+                    f"inp_{idx}", relax.TensorStructInfo(shape, 
self._convert_data_type(dtype))
+                )
+            )
+        return inputs
+
+    def create_convert_map(
+        self,
+    ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
         import operator
         from torch import nn
-        from torch import fx
 
-        self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], 
relax.Var]] = {
+        return {
             ## call_module
             # unary
             nn.Dropout: lambda node: self.env[node.args[0]],
@@ -1638,14 +1478,9 @@ class TorchFXImporter:
         self.named_modules = dict(model.named_modules())
 
         graph: fx.Graph = model.graph
+
         # Create input variables.
-        inputs = list()
-        for idx, (shape, dtype) in enumerate(input_info):
-            inputs.append(
-                relax.Var(
-                    f"inp_{idx}", relax.TensorStructInfo(shape, 
self._convert_data_type(dtype))
-                )
-            )
+        inputs = self.create_input_vars(input_info)
 
         # Initialize the block builder with a function and a dataflow block.
         func_name = "main"
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
new file mode 100644
index 0000000000..112390fe60
--- /dev/null
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -0,0 +1,535 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import torch
+from torch.nn import Module
+from torch.export import export
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+from tvm.relax.frontend.torch import from_exported_program
+
+
+def verify_model(torch_model, example_args, binding, expected):
+    exported_program = export(torch_model, args=example_args)
+    mod = from_exported_program(exported_program)
+
+    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+    expected = relax.transform.BindParams("main", binding)(expected)
+    tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_unary():
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    # dropout
+    class Dropout1(Module):
+        def __init__(self):
+            super().__init__()
+            self.dropout = torch.nn.Dropout(0.5)
+
+        def forward(self, input):
+            return self.dropout(input)
+
+    class Dropout2(Module):
+        def forward(self, input):
+            return torch.dropout(input, 0.5, train=True)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = 
(input_1,)
+                R.output(gv)
+            return gv
+
+    verify_model(Dropout1(), example_args, {}, expected1)
+    verify_model(Dropout2(), example_args, {}, expected1)
+
+    # relu
+    class ReLU0(Module):
+        def __init__(self):
+            super().__init__()
+            self.relu = torch.nn.ReLU()
+
+        def forward(self, input):
+            return self.relu(input)
+
+    class ReLU1(Module):
+        def forward(self, input):
+            return torch.nn.functional.relu(input)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.relu(input_1)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    verify_model(ReLU0(), example_args, {}, expected)
+    verify_model(ReLU1(), example_args, {}, expected)
+
+
+def test_adaptive_avgpool2d():
+    class AdaptiveAvgPool2d0(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class AdaptiveAvgPool2d1(Module):
+        def forward(self, input):
+            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.adaptive_avg_pool2d(
+                    input_1, output_size=[10, 10], layout="NCHW", 
out_layout="NCHW"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+    verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+
+
+def test_conv2d():
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    class Conv2D1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[6, 3, 7, 7])
+            self.bias = torch.randn(size=[6])
+
+        def forward(self, input):
+            return torch.nn.functional.conv2d(input, self.weight, self.bias)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+            w2: R.Tensor((6,), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    class Conv2D2(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    w1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = Conv2D1()
+    binding = {"w1": model.conv.weight.detach().numpy(), "w2": 
model.conv.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv2D1Func()
+    binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Conv2D2()
+    binding = {"w1": model.conv.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+
+def test_linear():
+    class Dense1(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=True)
+
+        def forward(self, input):
+            return self.linear(input)
+
+    class Dense1Func(Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.randn(size=[7, 10])
+            self.bias = torch.randn(size=[7])
+
+        def forward(self, input):
+            return torch.nn.functional.linear(input, self.weight, self.bias)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            w1: R.Tensor((7, 10), dtype="float32"),
+            w2: R.Tensor((7,), dtype="float32"),
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, 
axes=None)
+                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+                    input_1, lv, out_dtype="float32"
+                )
+                lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
+    class Dense2(Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = torch.nn.Linear(10, 7, bias=False)
+
+        def forward(self, input):
+            return self.linear(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            w1: R.Tensor((7, 10), dtype="float32"),
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, 
axes=None)
+                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+                    input_1, lv, out_dtype="float32"
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+    model = Dense1()
+    binding = {"w1": model.linear.weight.detach().numpy(), "w2": 
model.linear.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Dense1Func()
+    binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
+    verify_model(model, example_args, binding, expected1)
+
+    model = Dense2()
+    binding = {"w1": model.linear.weight.detach().numpy()}
+    verify_model(model, example_args, binding, expected2)
+
+
+def test_maxpool2d():
+    class MaxPool2d(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    class MaxPool2d_functional(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, input):
+            return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[1, 1],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class MaxPool2d2(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[2, 2],
+                    strides=[2, 2],
+                    dilation=[2, 3],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    class MaxPool2d3(Module):
+        def __init__(self):
+            super().__init__()
+            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, 
stride=2)
+
+        def forward(self, input):
+            return self.pool(input)
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d(
+                    input_1,
+                    pool_size=[4, 4],
+                    strides=[2, 2],
+                    dilation=[1, 1],
+                    padding=[2, 2, 2, 2],
+                    layout="NCHW",
+                    out_layout="NCHW",
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(MaxPool2d(), example_args, {}, expected1)
+    verify_model(MaxPool2d_functional(), example_args, {}, expected1)
+    verify_model(MaxPool2d2(), example_args, {}, expected2)
+    verify_model(MaxPool2d3(), example_args, {}, expected3)
+
+
+def test_view():
+    """``Tensor.view`` is expected to import as a single ``R.reshape`` to the target shape."""
+
+    class View(Module):
+        def forward(self, x):
+            return x.view(2, 12)
+
+    # (1, 2, 3, 4) has 24 elements, so view(2, 12) is a pure reshape.
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+                gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+    verify_model(View(), example_args, {}, expected1)
+
+
+def test_keep_params():
+    class Conv2D1(Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+        def forward(self, input):
+            return self.conv(input)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(
+            conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
+            conv_bias: R.Tensor((6,), dtype="float32"),
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+            R.func_attr({"num_input": 1})
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+                    input_1,
+                    conv_weight,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = 
R.reshape(conv_bias, [1, 6, 1, 1])
+                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+                gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
+                R.output(gv)
+            return gv
+
+    from tvm.relax.frontend import detach_params
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    model = Conv2D1()
+    exported_program = torch.export.export(model, example_args)
+    mod = from_exported_program(exported_program, keep_params_as_input=True)
+    mod, params = detach_params(mod)
+    tvm.ir.assert_structural_equal(mod, expected1)
+    func = mod["main"]
+    params = params["main"]
+
+    assert len(params) == len(func.params) - 1
+    for param_var, param_ndarray in zip(func.params[:-1], params):
+        assert tuple(x.value for x in param_var.struct_info.shape.values) == 
param_ndarray.shape
+        assert param_var.struct_info.dtype == param_ndarray.dtype
+
+    tvm.testing.assert_allclose(params[0].numpy(), 
model.conv.weight.detach().detach().numpy())
+    tvm.testing.assert_allclose(params[1].numpy(), 
model.conv.bias.detach().detach().numpy())
+
+
+def test_unwrap_unit_return_tuple():
+    class Identity(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return (x,)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tensor((256, 256), dtype="float32"):
+            with R.dataflow():
+                gv: R.Tensor((256, 256), dtype="float32") = inp_0
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+    exported_program = export(Identity(), args=example_args)
+    mod = from_exported_program(exported_program, 
unwrap_unit_return_tuple=True)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_no_bind_return_tuple():
+    class Identity(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return (x, y)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32"),
+            inp_1: R.Tensor((256, 256), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 
256), dtype="float32")):
+            with R.dataflow():
+                gv: R.Tensor((256, 256), dtype="float32") = inp_0
+                gv1: R.Tensor((256, 256), dtype="float32") = inp_1
+                R.output(gv, gv1)
+            return (gv, gv1)
+
+    example_args = (
+        torch.randn(256, 256, dtype=torch.float32),
+        torch.randn(256, 256, dtype=torch.float32),
+    )
+    exported_program = export(Identity(), args=example_args)
+    mod = from_exported_program(exported_program, no_bind_return_tuple=True)
+    tvm.ir.assert_structural_equal(mod, Expected)

Reply via email to