This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f2c91a652 [Relax][PyTorch] Add support for `torch.export.ExportedProgram` in Relax PyTorch Frontend (#17396)
3f2c91a652 is described below
commit 3f2c91a652a0a867703f2bc4176b80b2d1747c25
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Sep 27 10:00:17 2024 +0900
[Relax][PyTorch] Add support for `torch.export.ExportedProgram` in Relax PyTorch Frontend (#17396)
* introduce ExportedProgramImporter
* address review comments
---
python/tvm/relax/frontend/torch/__init__.py | 1 +
.../frontend/torch/base_fx_graph_translator.py | 228 +++++++++
.../frontend/torch/exported_program_translator.py | 243 ++++++++++
python/tvm/relax/frontend/torch/fx_translator.py | 209 +-------
.../relax/test_frontend_from_exported_program.py | 535 +++++++++++++++++++++
5 files changed, 1029 insertions(+), 187 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py
index 55da5a456d..36eac975df 100644
--- a/python/tvm/relax/frontend/torch/__init__.py
+++ b/python/tvm/relax/frontend/torch/__init__.py
@@ -17,5 +17,6 @@
"""
PyTorch Frontends for constructing Relax programs, with the model importers
"""
+from .exported_program_translator import from_exported_program
from .fx_translator import from_fx
from .dynamo import relax_dynamo, dynamo_capture_subgraphs
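
With this, the package exposes the new entry point alongside the existing ones; a quick illustration of the resulting import surface:

    from tvm.relax.frontend.torch import from_exported_program, from_fx, relax_dynamo
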
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
new file mode 100644
index 0000000000..6a001b5a04
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""Base class for PyTorch FX Graph importer."""
+import abc
+from typing import Callable, Dict, Optional, Tuple, Union
+
+from tvm import relax
+
+
+class BaseFXGraphImporter(metaclass=abc.ABCMeta):
+ """Base class for FX Graph Importer."""
+
+ import torch # type: ignore
+ from torch import fx
+
+ def __init__(self) -> None:
+ import torch # type: ignore
+ from torch import fx
+
+ self.env: Dict[fx.Node, relax.Expr] = {}
+ self.params: Dict[torch.Tensor, relax.Expr] = {}
+ self.block_builder: relax.BlockBuilder = None
+ self.convert_map: Dict[
+ Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]
+ ] = self.create_convert_map()
+
+ ########## Utilities ##########
+
+ @staticmethod
+ def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None):
+ """Convert the PyTorch scalar type input_type to a TVM dtype."""
+ import torch # type: ignore
+
+ if env is not None and input_type in env:
+ input_type = env[input_type]
+
+ input_type = input_type.lower() if isinstance(input_type, str) else input_type
+ if input_type in ["float", "float32", "torch.float32", torch.float32]:
+ return "float32"
+ elif input_type in ["float16", "torch.float16", torch.float16]:
+ return "float16"
+ elif input_type in ["int64", "torch.int64", torch.int64]:
+ return "int64"
+ elif input_type in ["int32", "torch.int32", torch.int32]:
+ return "int32"
+ elif input_type in ["bool", "torch.bool", torch.bool]:
+ return "bool"
+ else:
+ raise NotImplementedError("input_type {} is not handled yet".format(input_type))
+
+ @staticmethod
+ def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
+ tensor = tensor.detach().cpu()
+ dtype = BaseFXGraphImporter._convert_data_type(str(tensor.data.dtype))
+ return relax.const(tensor.data.numpy(), dtype)
+
+ @staticmethod
+ def shape_of(tensor):
+ """Get the shape of a tensor."""
+ import torch # type: ignore
+
+ if isinstance(tensor, relax.Expr):
+ if not isinstance(tensor.struct_info, relax.TensorStructInfo):
+ raise TypeError("The input Expr of shape_of should be a Tensor")
+ return tensor.struct_info.shape
+ elif isinstance(tensor, torch.Tensor):
+ return tensor.shape
+ raise ValueError("Unsupported type: {}".format(type(tensor)))
+
+ def retrieve_args(self, node: fx.Node):
+ return self._retrieve_args(node.args)
+
+ def _retrieve_args(self, node):
+ from torch import fx
+
+ if isinstance(node, fx.Node):
+ return self.env[node]
+ elif isinstance(node, tuple):
+ return tuple(self._retrieve_args(x) for x in node)
+ elif isinstance(node, list):
+ return [self._retrieve_args(x) for x in node]
+ elif isinstance(node, dict):
+ return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()}
+ else:
+ return node
+
+ ########## Unary Ops ##########
+
+ def _unary_op(self, op: Callable) -> Callable:
+ from torch import fx
+
+ def convert(node: fx.Node) -> relax.Var:
+ return self.block_builder.emit(op(self.env[node.args[0]]))
+
+ return convert
+
+ ########## Neural Network ##########
+
+ def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ output_size = node.args[1]
+ return self.block_builder.emit(
+ relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
+ )
+
+ def _conv2d_impl(
+ self,
+ x: relax.Expr,
+ weight: relax.Expr,
+ bias: Optional[relax.Expr],
+ strides: Optional[Tuple],
+ padding: Optional[Tuple],
+ dilation: Optional[Tuple],
+ groups: Optional[Tuple],
+ ):
+ conv2d = self.block_builder.emit(
+ relax.op.nn.conv2d(
+ x,
+ weight,
+ strides=strides,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_dtype="float32",
+ )
+ )
+
+ if bias is None:
+ return conv2d
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1, 1))
+ return self.block_builder.emit(relax.op.add(conv2d, bias))
+
+ def _conv2d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ groups = args[6] if len(args) > 6 else 1
+ return self._conv2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+
+ def _linear(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
+
+ def _max_pool2d_impl(
+ self,
+ x: relax.Expr,
+ kernel_size: Union[int, Tuple[int, int]] = (1, 1),
+ stride: Optional[Union[int, Tuple[int, int]]] = None,
+ padding: Optional[int] = 0,
+ dilation: Optional[int] = 1,
+ ceil_mode: Optional[bool] = False,
+ ) -> relax.Var:
+ stride = kernel_size if stride is None else stride
+ return self.block_builder.emit(
+ relax.op.nn.max_pool2d(
+ x,
+ pool_size=kernel_size,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ ceil_mode=ceil_mode,
+ layout="NCHW",
+ )
+ )
+
+ def _max_pool2d(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+
+ return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+
+ ########## Manipulation ##########
+
+ def _reshape(self, node: fx.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ x = args[0]
+ dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+ return self.block_builder.emit(relax.op.reshape(x, dims))
+
+ ########## Others ##########
+
+ @abc.abstractmethod
+ def create_convert_map(
+ self,
+ ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
+ """Create convert map"""
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
new file mode 100644
index 0000000000..9af422d1c3
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -0,0 +1,243 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""PyTorch ExportedProgram of Relax."""
+from collections import ChainMap, OrderedDict
+from typing import Callable, Dict, List, Tuple
+
+import torch
+import tvm
+from tvm import relax
+
+from .base_fx_graph_translator import BaseFXGraphImporter
+
+
+class ExportedProgramImporter(BaseFXGraphImporter):
+ """An importer from ExportedProgram to Relax."""
+
+ from torch import fx
+
+ def create_input_vars(
+ self, exported_program: torch.export.ExportedProgram
+ ) -> Tuple[List[relax.Var], List[relax.Var]]:
+ """Create relax input vars."""
+ parameters_buffers_constants = []
+ user_inputs = []
+ for spec in exported_program.graph_signature.input_specs:
+ name_hint = spec.arg.name
+ if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
+ shape = exported_program.tensor_constants[spec.target].shape
+ torch_dtype = exported_program.tensor_constants[spec.target].dtype
+ elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+ for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
+ if node.name == name_hint:
+ shape = node.meta["tensor_meta"].shape
+ torch_dtype = node.meta["tensor_meta"].dtype
+ break
+ else:
+ # PARAMETER or BUFFER
+ shape = exported_program.state_dict[spec.target].shape
+ torch_dtype = exported_program.state_dict[spec.target].dtype
+
+ dtype = self._convert_data_type(torch_dtype)
+ relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
+ if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
+ user_inputs.append(relax_var)
+ else:
+ parameters_buffers_constants.append(relax_var)
+
+ return parameters_buffers_constants, user_inputs
+
+ def create_convert_map(
+ self,
+ ) -> Dict[str, Callable[[fx.Node], relax.Var]]:
+ return {
+ # unary
+ "dropout.default": lambda node: self.env[node.args[0]],
+ "relu.default": self._unary_op(relax.op.nn.relu),
+ # neural network
+ "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
+ "conv2d.default": self._conv2d,
+ "linear.default": self._linear,
+ "max_pool2d.default": self._max_pool2d,
+ # tensor manipulation
+ "view.default": self._reshape,
+ }
+
+ def from_exported_program(
+ self,
+ exported_program: torch.export.ExportedProgram,
+ keep_params_as_input: bool,
+ unwrap_unit_return_tuple: bool,
+ no_bind_return_tuple: bool,
+ ) -> tvm.IRModule:
+ """Convert a PyTorch ExportedProgram to a Relax program."""
+ from torch import fx # type: ignore
+
+ # Create input variables.
+ parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+ inputs_vars = parameter_buffer_constant_vars + user_input_vars
+
+ # Initialize the block builder with a function and a dataflow block.
+ self.block_builder = relax.BlockBuilder()
+ func_name = "main"
+ func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+
+ nodes: List[fx.Node] = exported_program.graph.nodes
+ with self.block_builder.function(
+ name=func_name, params=inputs_vars.copy(), attrs=func_attrs
+ ):
+ output = None
+ with self.block_builder.dataflow():
+ # Translate the model.
+ for node in nodes:
+ if node.op == "placeholder":
+ if "grapharg" in node.meta and
node.meta["grapharg"].fake_tensor is None:
+ # Ignore sym input
+ continue
+
+ self.env[node] = inputs_vars.pop(0)
+ elif node.op == "output":
+ args = self.retrieve_args(node)
+ assert len(args) == 1
+ assert isinstance(args[0], (tuple, relax.Tuple))
+
+ if unwrap_unit_return_tuple and len(args[0]) == 1:
+ output = self.block_builder.emit_output(args[0][0])
+ elif no_bind_return_tuple:
+ output = []
+ for ret in args[0]:
+ output.append(self.block_builder.emit_output(ret))
+ else:
+ output = self.block_builder.emit_output(args[0])
+ break
+ elif node.op == "get_attr":
+ self.env[node] = getattr(exported_program.graph_module, node.target)
+ elif node.op == "call_function":
+ func_name = node.target.__name__
+ assert (
+ func_name in self.convert_map
+ ), f"Unsupported function type {func_name}"
+ self.env[node] = self.convert_map[func_name](node)
+ else:
+ raise ValueError(f"Unsupported op {node.op}")
+ assert output is not None
+ self.block_builder.emit_func_output(output)
+
+ to_bind_parameters = ChainMap(
+ OrderedDict(exported_program.named_buffers()), exported_program.constants
+ )
+ if not keep_params_as_input:
+ to_bind_parameters = to_bind_parameters.new_child(
+ OrderedDict(exported_program.named_parameters())
+ )
+
+ binding = {}
+ for tensor_name, tensor_value in to_bind_parameters.items():
+ # find relax var name from graph signature
+ for spec in exported_program.graph_signature.input_specs:
+ if tensor_name == spec.target:
+ bind_name = spec.arg.name
+ break
+ binding[bind_name] = tvm.nd.from_dlpack(tensor_value.detach())
+
+ mod = self.block_builder.get()
+ mod = relax.transform.BindParams("main", binding)(mod)
+
+ if keep_params_as_input:
+ parameters = dict(exported_program.named_parameters())
+ params = [tvm.nd.from_dlpack(p.detach()) for p in parameters.values()]
+ mod["main"] = mod["main"].with_attr("params", params)
+
+ return mod
+
+
+def from_exported_program(
+ exported_program: torch.export.ExportedProgram,
+ *,
+ keep_params_as_input: bool = False,
+ unwrap_unit_return_tuple: bool = False,
+ no_bind_return_tuple: bool = False,
+) -> tvm.IRModule:
+ """Convert a PyTorch ExportedProgram to a Relax program
+
+ Parameters
+ ----------
+ exported_program : torch.export.ExportedProgram
+ The PyTorch ExportedProgram to convert.
+
+ keep_params_as_input : bool
+ Whether to keep model parameters as input variables.
+
+ unwrap_unit_return_tuple : bool
+ A boolean flag indicating whether to unwrap the return value when it is a unit tuple.
+ When the return value is not a unit tuple, no unwrapping takes place.
+
+ no_bind_return_tuple : bool
+ A boolean flag indicating whether to bind the return tuple as a relax var.
+ If the flag is true and the return value is a tuple, it will not be bound to a var.
+
+ Returns
+ -------
+ output : tvm.IRModule
+ The import result IRModule, with the function "main" containing the
+ translated logic.
+
+ Examples
+ --------
+ Users can use torch.export.export() to extract a torch.export.ExportedProgram
+ from a PyTorch model. The following code shows how to convert a PyTorch model to
+ a Relax program.
+
+ .. code-block:: python
+
+ # Import the importer.
+ import tvm
+ from tvm.relax.frontend.torch import from_exported_program
+ import torch
+ from torch.export import export
+
+ # Define the module
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)
+
+ def forward(self, input):
+ return self.linear(input)
+
+ # Instantiate the model and create the input info dict.
+ torch_model = MyModule()
+
+ # Use torch.export.export() to convert the PyTorch model into ExportedProgram.
+ example_args = (torch.rand(128, 10, dtype=torch.float32),)
+ exported_program = export(torch_model, args=example_args)
+
+ # Use the importer to import the ExportedProgram to Relax.
+ mod: tvm.IRModule = from_exported_program(exported_program)
+ """
+ # decompose into Core ATen operators
+ exported_program.run_decompositions()
+
+ return ExportedProgramImporter().from_exported_program(
+ exported_program,
+ keep_params_as_input,
+ unwrap_unit_return_tuple,
+ no_bind_return_tuple,
+ )
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 27da69dbb1..ec53cf23ed 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -24,8 +24,10 @@ from functools import partial, reduce
import tvm
from tvm import relax
+from .base_fx_graph_translator import BaseFXGraphImporter
-class TorchFXImporter:
+
+class TorchFXImporter(BaseFXGraphImporter):
"""An importer from PyTorch FX to Relax."""
import torch # type: ignore
@@ -33,15 +35,12 @@ class TorchFXImporter:
def __init__(self) -> None:
import torch # type: ignore
- from torch import fx
- self.env: Dict[fx.Node, relax.Expr] = {}
- self.params: Dict[torch.Tensor, relax.Expr] = {}
+ super().__init__()
self.named_modules: Dict[str, torch.Module] = None
- self.block_builder: relax.BlockBuilder = None
- self.create_convert_map()
########## Utilities ##########
+
def _fetch_attr(self, model, target: str):
import torch # type: ignore
@@ -58,77 +57,11 @@ class TorchFXImporter:
# If so, return the parameter instead.
if attr_itr in self.params:
return self.params[attr_itr]
- return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
+ return self._convert_torch_tensor_to_relax(attr_itr)
return attr_itr
- @staticmethod
- def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None):
- """converts the PyTorch scalar type input_type to a TVM dtype."""
- import torch # type: ignore
-
- if env is not None and input_type in env:
- input_type = env[input_type]
-
- input_type = input_type.lower() if isinstance(input_type, str) else input_type
- if input_type in ["float", "float32", "torch.float32", torch.float32]:
- return "float32"
- elif input_type in ["float16", "torch.float16", torch.float16]:
- return "float16"
- elif input_type in ["int64", "torch.int64", torch.int64]:
- return "int64"
- elif input_type in ["int32", "torch.int32", torch.int32]:
- return "int32"
- elif input_type in ["bool", "torch.bool", torch.bool]:
- return "bool"
- else:
- raise NotImplementedError("input_type {} is not handled yet".format(input_type))
-
- @staticmethod
- def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
- tensor = tensor.detach().cpu()
- dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
- return relax.const(tensor.data.numpy(), dtype)
-
- @staticmethod
- def shape_of(tensor):
- """Get the shape of a tensor."""
- import torch # type: ignore
-
- if isinstance(tensor, relax.Expr):
- if not isinstance(tensor.struct_info, relax.TensorStructInfo):
- raise TypeError("The input Expr of shape_of should be a Tensor")
- return tensor.struct_info.shape
- elif isinstance(tensor, torch.Tensor):
- return tensor.shape
- raise ValueError("Unsupported type: {}".format(type(tensor)))
-
- def retrieve_args(self, node):
- return self._retrieve_args(node.args)
-
- def _retrieve_args(self, node):
- from torch import fx
-
- if isinstance(node, fx.Node):
- return self.env[node]
- elif isinstance(node, tuple):
- return tuple(self._retrieve_args(x) for x in node)
- elif isinstance(node, list):
- return [self._retrieve_args(x) for x in node]
- elif isinstance(node, dict):
- return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()}
- else:
- return node
-
########## Unary Ops ##########
- def _unary_op(self, op: Callable) -> Callable:
- from torch import fx
-
- def convert(node: fx.Node) -> relax.Var:
- return self.block_builder.emit(op(self.env[node.args[0]]))
-
- return convert
-
def _clamp(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
a_min = args[1] if len(args) > 1 else node.kwargs["min"]
@@ -272,13 +205,6 @@ class TorchFXImporter:
########## Neural Network ##########
- def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
- x = self.env[node.args[0]]
- output_size = node.args[1]
- return self.block_builder.emit(
- relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
- )
-
def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var:
module = self.named_modules[node.target]
@@ -590,55 +516,6 @@ class TorchFXImporter:
groups=module.groups,
)
- def _conv2d_impl(
- self,
- x: relax.Expr,
- weight: relax.Expr,
- bias: Optional[relax.Expr],
- strides: Optional[Tuple],
- padding: Optional[Tuple],
- dilation: Optional[Tuple],
- groups: Optional[Tuple],
- ):
- conv2d = self.block_builder.emit(
- relax.op.nn.conv2d(
- x,
- weight,
- strides=strides,
- padding=padding,
- dilation=dilation,
- groups=groups,
- data_layout="NCHW",
- kernel_layout="OIHW",
- out_dtype="float32",
- )
- )
-
- if bias is None:
- return conv2d
- assert len(self.shape_of(bias)) == 1
- bias = relax.op.reshape(bias, (1, -1, 1, 1))
- return self.block_builder.emit(relax.op.add(conv2d, bias))
-
- def _conv2d(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- stride = args[3] if len(args) > 3 else 1
- padding = args[4] if len(args) > 4 else 0
- dilation = args[5] if len(args) > 5 else 1
- groups = args[6] if len(args) > 6 else 1
- return self._conv2d_impl(
- x,
- weight,
- bias=bias,
- strides=stride,
- padding=padding,
- dilation=dilation,
- groups=groups,
- )
-
def _conv2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -940,13 +817,6 @@ class TorchFXImporter:
eps = module.eps
return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape)
- def _linear(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- weight = args[1]
- bias = args[2] if len(args) > 2 else None
- return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
-
def _linear_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -954,39 +824,6 @@ class TorchFXImporter:
bias = self.params.get(module.bias, None)
return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32"))
- def _max_pool2d_impl(
- self,
- x: relax.Expr,
- kernel_size: Union[int, Tuple[int, int]] = (1, 1),
- stride: Optional[Union[int, Tuple[int, int]]] = None,
- padding: Optional[int] = 0,
- dilation: Optional[int] = 1,
- ceil_mode: Optional[bool] = False,
- ) -> relax.Var:
- stride = kernel_size if stride is None else stride
- return self.block_builder.emit(
- relax.op.nn.max_pool2d(
- x,
- pool_size=kernel_size,
- strides=stride,
- padding=padding,
- dilation=dilation,
- ceil_mode=ceil_mode,
- layout="NCHW",
- )
- )
-
- def _max_pool2d(self, node: fx.Node) -> relax.Var:
- args = self.retrieve_args(node)
- x = args[0]
- kernel_size = args[1]
- stride = args[2] if len(args) > 2 else None
- padding = args[3] if len(args) > 3 else 0
- dilation = args[4] if len(args) > 4 else 1
- ceil_mode = args[5] if len(args) > 5 else False
-
- return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
-
def _max_pool2d_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
module = self.named_modules[node.target]
@@ -1138,14 +975,6 @@ class TorchFXImporter:
dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
return self.block_builder.emit(relax.op.tile(x, dims))
- def _reshape(self, node: fx.Node) -> relax.Var:
- import torch # type: ignore
-
- args = self.retrieve_args(node)
- x = args[0]
- dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
- return self.block_builder.emit(relax.op.reshape(x, dims))
-
def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -1448,12 +1277,23 @@ class TorchFXImporter:
idx = node.args[1]
return self.block_builder.emit(relax.const(shape[idx].value, "int32"))
- def create_convert_map(self):
+ def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]:
+ inputs = list()
+ for idx, (shape, dtype) in enumerate(input_info):
+ inputs.append(
+ relax.Var(
+ f"inp_{idx}", relax.TensorStructInfo(shape,
self._convert_data_type(dtype))
+ )
+ )
+ return inputs
+
+ def create_convert_map(
+ self,
+ ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
import operator
from torch import nn
- from torch import fx
- self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = {
+ return {
## call_module
# unary
nn.Dropout: lambda node: self.env[node.args[0]],
@@ -1638,14 +1478,9 @@ class TorchFXImporter:
self.named_modules = dict(model.named_modules())
graph: fx.Graph = model.graph
+
# Create input variables.
- inputs = list()
- for idx, (shape, dtype) in enumerate(input_info):
- inputs.append(
- relax.Var(
- f"inp_{idx}", relax.TensorStructInfo(shape,
self._convert_data_type(dtype))
- )
- )
+ inputs = self.create_input_vars(input_info)
# Initialize the block builder with a function and a dataflow block.
func_name = "main"
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
new file mode 100644
index 0000000000..112390fe60
--- /dev/null
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -0,0 +1,535 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import torch
+from torch.nn import Module
+from torch.export import export
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+from tvm.relax.frontend.torch import from_exported_program
+
+
+def verify_model(torch_model, example_args, binding, expected):
+ exported_program = export(torch_model, args=example_args)
+ mod = from_exported_program(exported_program)
+
+ binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+ expected = relax.transform.BindParams("main", binding)(expected)
+ tvm.ir.assert_structural_equal(mod, expected)
+
+
+def test_unary():
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ # dropout
+ class Dropout1(Module):
+ def __init__(self):
+ super().__init__()
+ self.dropout = torch.nn.Dropout(0.5)
+
+ def forward(self, input):
+ return self.dropout(input)
+
+ class Dropout2(Module):
+ def forward(self, input):
+ return torch.dropout(input, 0.5, train=True)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,)
+ R.output(gv)
+ return gv
+
+ verify_model(Dropout1(), example_args, {}, expected1)
+ verify_model(Dropout2(), example_args, {}, expected1)
+
+ # relu
+ class ReLU0(Module):
+ def __init__(self):
+ super().__init__()
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, input):
+ return self.relu(input)
+
+ class ReLU1(Module):
+ def forward(self, input):
+ return torch.nn.functional.relu(input)
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(ReLU0(), example_args, {}, expected)
+ verify_model(ReLU1(), example_args, {}, expected)
+
+
+def test_adaptive_avgpool2d():
+ class AdaptiveAvgPool2d0(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class AdaptiveAvgPool2d1(Module):
+ def forward(self, input):
+ return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d(
+ input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW"
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1)
+ verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1)
+
+
+def test_conv2d():
+ class Conv2D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ class Conv2D1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[6, 3, 7, 7])
+ self.bias = torch.randn(size=[6])
+
+ def forward(self, input):
+ return torch.nn.functional.conv2d(input, self.weight, self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+ w2: R.Tensor((6,), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+ input_1,
+ w1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1])
+ lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ class Conv2D2(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+ input_1,
+ w1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = Conv2D1()
+ binding = {"w1": model.conv.weight.detach().numpy(), "w2":
model.conv.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv2D1Func()
+ binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Conv2D2()
+ binding = {"w1": model.conv.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
+
+
+def test_linear():
+ class Dense1(Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(10, 7, bias=True)
+
+ def forward(self, input):
+ return self.linear(input)
+
+ class Dense1Func(Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.randn(size=[7, 10])
+ self.bias = torch.randn(size=[7])
+
+ def forward(self, input):
+ return torch.nn.functional.linear(input, self.weight, self.bias)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ w1: R.Tensor((7, 10), dtype="float32"),
+ w2: R.Tensor((7,), dtype="float32"),
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
+ lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+ input_1, lv, out_dtype="float32"
+ )
+ lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,)
+ R.output(gv)
+ return gv
+
+ class Dense2(Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(10, 7, bias=False)
+
+ def forward(self, input):
+ return self.linear(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ w1: R.Tensor((7, 10), dtype="float32"),
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
+ lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
+ input_1, lv, out_dtype="float32"
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ model = Dense1()
+ binding = {"w1": model.linear.weight.detach().numpy(), "w2":
model.linear.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Dense1Func()
+ binding = {"w1": model.weight.detach().numpy(), "w2":
model.bias.detach().numpy()}
+ verify_model(model, example_args, binding, expected1)
+
+ model = Dense2()
+ binding = {"w1": model.linear.weight.detach().numpy()}
+ verify_model(model, example_args, binding, expected2)
+
+
+def test_maxpool2d():
+ class MaxPool2d(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ class MaxPool2d_functional(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1])
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d(
+ input_1,
+ pool_size=[1, 1],
+ strides=[1, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class MaxPool2d2(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected2:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d(
+ input_1,
+ pool_size=[2, 2],
+ strides=[2, 2],
+ dilation=[2, 3],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ class MaxPool2d3(Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)
+
+ def forward(self, input):
+ return self.pool(input)
+
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d(
+ input_1,
+ pool_size=[4, 4],
+ strides=[2, 2],
+ dilation=[1, 1],
+ padding=[2, 2, 2, 2],
+ layout="NCHW",
+ out_layout="NCHW",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ verify_model(MaxPool2d(), example_args, {}, expected1)
+ verify_model(MaxPool2d_functional(), example_args, {}, expected1)
+ verify_model(MaxPool2d2(), example_args, {}, expected2)
+ verify_model(MaxPool2d3(), example_args, {}, expected3)
+
+
+def test_view():
+ class View(Module):
+ def forward(self, x):
+ return x.view(2, 12)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 3, 4), dtype="float32")
+ ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
+ gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
+ verify_model(View(), example_args, {}, expected1)
+
+
+def test_keep_params():
+ class Conv2D1(Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)
+
+ def forward(self, input):
+ return self.conv(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"),
+ conv_bias: R.Tensor((6,), dtype="float32"),
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")):
+ R.func_attr({"num_input": 1})
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
+ input_1,
+ conv_weight,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="float32",
+ )
+ lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1])
+ lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
+ gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ from tvm.relax.frontend import detach_params
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+ model = Conv2D1()
+ exported_program = torch.export.export(model, example_args)
+ mod = from_exported_program(exported_program, keep_params_as_input=True)
+ mod, params = detach_params(mod)
+ tvm.ir.assert_structural_equal(mod, expected1)
+ func = mod["main"]
+ params = params["main"]
+
+ assert len(params) == len(func.params) - 1
+ for param_var, param_ndarray in zip(func.params[:-1], params):
+ assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape
+ assert param_var.struct_info.dtype == param_ndarray.dtype
+
+ tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy())
+ tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy())
+
+
+def test_unwrap_unit_return_tuple():
+ class Identity(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return (x,)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32")
+ ) -> R.Tensor((256, 256), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((256, 256), dtype="float32") = inp_0
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(256, 256, dtype=torch.float32),)
+ exported_program = export(Identity(), args=example_args)
+ mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_no_bind_return_tuple():
+ class Identity(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return (x, y)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((256, 256), dtype="float32"),
+ inp_1: R.Tensor((256, 256), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")):
+ with R.dataflow():
+ gv: R.Tensor((256, 256), dtype="float32") = inp_0
+ gv1: R.Tensor((256, 256), dtype="float32") = inp_1
+ R.output(gv, gv1)
+ return (gv, gv1)
+
+ example_args = (
+ torch.randn(256, 256, dtype=torch.float32),
+ torch.randn(256, 256, dtype=torch.float32),
+ )
+ exported_program = export(Identity(), args=example_args)
+ mod = from_exported_program(exported_program, no_bind_return_tuple=True)
+ tvm.ir.assert_structural_equal(mod, Expected)
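
Taken together, the two return-tuple flags exercised above behave as follows (a sketch reusing the Identity module from the tests):

    import torch
    from torch.export import export

    from tvm.relax.frontend.torch import from_exported_program


    class Identity(torch.nn.Module):
        def forward(self, x):
            return (x,)


    ep = export(Identity(), args=(torch.randn(256, 256, dtype=torch.float32),))

    # Default: "main" returns the unit tuple (x,) bound to a single relax var.
    mod_default = from_exported_program(ep)

    # unwrap_unit_return_tuple=True: the 1-tuple is unwrapped, so "main"
    # returns a bare tensor instead of a tuple.
    mod_unwrapped = from_exported_program(ep, unwrap_unit_return_tuple=True)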