This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 877fbf1867 [Unity] Initial PyTorch Frontend (#14037)
877fbf1867 is described below
commit 877fbf186718a241dbb1730eacec3c0b8496765d
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Feb 18 11:19:48 2023 -0500
[Unity] Initial PyTorch Frontend (#14037)
[Unity] Initial PyTorch Frontend
This PR introduces the initial PyTorch frontend components of Relax, including
- an FX translator that translates a Torch FX graph module to a TVM
IRModule,
- a Relax backend of Torch Dynamo, which brings the mechanism to build
PyTorch models using the Relax compilation pipeline,
- a pipeline prototype that contains the collection of pre-defined
pipelines that optimize and lower an IRModule before passing it to the minimum build.
Co-authored-by: Bohan Hou
<[email protected]>
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
---
python/tvm/relax/__init__.py | 3 +
python/tvm/relax/frontend/__init__.py | 19 +
python/tvm/relax/frontend/torch/__init__.py | 21 +
python/tvm/relax/frontend/torch/dynamo.py | 156 ++
python/tvm/relax/frontend/torch/fx_translator.py | 820 ++++++++++
python/tvm/relax/pipeline.py | 84 ++
tests/python/relax/test_frontend_dynamo.py | 198 +++
tests/python/relax/test_frontend_from_fx.py | 1729 ++++++++++++++++++++++
tests/python/relax/test_pipeline.py | 45 +
9 files changed, 3075 insertions(+)
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index cfcf7876dc..33a9c2eece 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -73,6 +73,9 @@ from .struct_info import (
FuncStructInfo,
)
+# pipeline
+from .pipeline import get_pipeline
+
# Import submodules in the last to avoid dependency
from . import exec_builder
from . import expr
diff --git a/python/tvm/relax/frontend/__init__.py
b/python/tvm/relax/frontend/__init__.py
new file mode 100644
index 0000000000..6c9c188aaa
--- /dev/null
+++ b/python/tvm/relax/frontend/__init__.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Frontends for constructing Relax programs, with the model importers
+"""
diff --git a/python/tvm/relax/frontend/torch/__init__.py
b/python/tvm/relax/frontend/torch/__init__.py
new file mode 100644
index 0000000000..55da5a456d
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+PyTorch Frontends for constructing Relax programs, with the model importers
+"""
+from .fx_translator import from_fx
+from .dynamo import relax_dynamo, dynamo_capture_subgraphs
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
new file mode 100644
index 0000000000..94de73a431
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, missing-function-docstring, not-callable
+# pylint: disable=import-outside-toplevel, unused-argument
+# mypy: ignore-errors
+"""PyTorch Dynamo backend of Relax."""
+import functools
+from typing import Optional
+
+import tvm
+from tvm.relax.vm import build as relax_build
+from tvm.relax.frontend.torch.fx_translator import from_fx
+
+
+def device_from_inputs(example_inputs):
+ for x in example_inputs:
+ if hasattr(x, "device"):
+ return x.device
+ return None
+
+
+def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None):
+ """A helper function to create a relax backend.
+
+ Parameters
+ ----------
+ pipeline : Optional[tvm.transform.Pass]
+ The pipeline to be applied to the relax module before sent to build.
+
+ Returns
+ -------
+ backend : Callable[[torch.fx.GraphModule, List[torch.Tensor]], Callable]
+ The relax dynamo backend.
+ """
+
+ def _relax_backend(graph_module, example_inputs):
+ import torch # type: ignore[import]
+
+ assert isinstance(graph_module, torch.fx.GraphModule)
+
+ def to_torch_tensor(nd_tensor):
+ """A helper function to transfer a NDArray to torch.tensor."""
+ if isinstance(nd_tensor, tvm.nd.NDArray):
+ return torch.from_numpy(nd_tensor.numpy())
+ elif isinstance(nd_tensor, tvm.ir.Array):
+ return tuple(to_torch_tensor(x) for x in nd_tensor)
+ else:
+ raise ValueError(f"Unsupported type {type(nd_tensor)}")
+
+ def to_tvm_tensor(torch_tensor):
+ """A helper function to transfer a torch.tensor to NDArray."""
+ if not isinstance(torch_tensor,
torch._subclasses.fake_tensor.FakeTensor):
+ return tvm.nd.array(torch_tensor.numpy())
+ # Fake Tensor
+ real_tensor = torch.randn(torch_tensor.shape,
dtype=torch_tensor.dtype)
+ return tvm.nd.array(real_tensor.numpy())
+
+ device = device_from_inputs(example_inputs)
+ input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
+ mod = from_fx(graph_module, input_info)
+
+ if device.type == "cuda":
+ dev = tvm.cuda(device.index)
+ target = tvm.target.cuda()
+ else:
+ dev = tvm.cpu(0)
+ target = tvm.target.Target(llvm_target())
+
+ # invoke optimization pipeline.
+ if pipeline is None:
+ # get default pipeline
+ seq = tvm.relax.get_pipeline()
+ elif isinstance(pipeline, str):
+ # lookup by name
+ seq = tvm.relax.get_pipeline(pipeline)
+ else:
+ seq = pipeline
+
+ mod = mod.with_attr("target", target)
+ mod = seq(mod)
+
+ ex = relax_build(mod, target=target)
+
+ vm = tvm.relax.vm.VirtualMachine(exec=ex.mod, device=dev)
+
+ def exec_tvm(*i_args):
+ args = [a.contiguous() for a in i_args]
+ vm_args = list()
+ for arg in args:
+ if arg.dim() != 0:
+ if arg.requires_grad:
+ arg = arg.detach()
+ vm_args.append(to_tvm_tensor(arg))
+ outputs = vm["main"](*vm_args)
+ return to_torch_tensor(outputs)
+
+ return exec_tvm
+
+ return _relax_backend
+
+
+def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule:
+ """Capture subgraphs of the PyTorch model using torch.compile into an
IRModule.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ The PyTorch model to be captured.
+
+ params : List[torch.Tensor]
+ The parameters of the PyTorch model.
+
+ Returns
+ -------
+ mod : tvm.ir.IRModule
+ The IRModule that contains captured subgraphs.
+ """
+ import torch # type: ignore[import]
+ from torch import fx # type: ignore[import]
+ from torch import _dynamo as dynamo # type: ignore[import]
+
+ mod = tvm.IRModule()
+
+ def _capture(graph_module: fx.GraphModule, example_inputs):
+ assert isinstance(graph_module, torch.fx.GraphModule)
+ input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in
example_inputs]
+ subgraph = from_fx(graph_module, input_info)
+ mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"]
+ return graph_module.forward
+
+ dynamo.reset()
+ compiled_model = torch.compile(model, backend=_capture)
+ compiled_model(*params)
+ return mod
+
+
[email protected]_cache(None)
+def llvm_target():
+ if "avx512" in open("/proc/cpuinfo").read():
+ return "llvm -mcpu=skylake-avx512"
+ return "llvm -mcpu=core-avx2"
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
new file mode 100644
index 0000000000..582f2edbcf
--- /dev/null
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -0,0 +1,820 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, inconsistent-return-statements,
unidiomatic-typecheck
+# pylint: disable=import-outside-toplevel
+"""PyTorch FX frontend of Relax."""
+from typing import Callable, Dict, List, Tuple, Union
+from functools import reduce
+
+import tvm
+from tvm import relax
+
+
+class TorchFXImporter:
+ """An importer from PyTorch FX to Relax."""
+
+ import torch # type: ignore
+ from torch import fx
+
+ def __init__(self) -> None:
+ import torch # type: ignore
+ from torch import fx
+
+ self.env: Dict[fx.node.Node, relax.Expr] = {}
+ self.params: Dict[torch.Tensor, relax.Constant] = {}
+ self.named_modules: Dict[str, torch.Module] = None
+ self.block_builder: relax.BlockBuilder = None
+ self.create_convert_map()
+
+ ########## Utilities ##########
+ @staticmethod
+ def _fetch_attr(model, target: str):
+ import torch # type: ignore
+
+ target_atoms = target.split(".")
+ attr_itr = model
+ for i, atom in enumerate(target_atoms):
+ if not hasattr(attr_itr, atom):
+ raise RuntimeError(
+ f"Node referenced non existing target
{'.'.join(target_atoms[:i])}"
+ )
+ attr_itr = getattr(attr_itr, atom)
+ if isinstance(attr_itr, torch.Tensor):
+ return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr)
+ return attr_itr
+
+ @staticmethod
+ def _convert_data_type(input_type):
+ """converts the PyTorch scalar type input_type to a TVM dtype."""
+ import torch # type: ignore
+
+ input_type = input_type.lower()
+ if input_type in ["float", "float32", "torch.float32", torch.float32]:
+ return "float32"
+ elif input_type in ["float16", "torch.float16", torch.float16]:
+ return "float16"
+ elif input_type in ["int64", "torch.int64", torch.int64]:
+ return "int64"
+ else:
+ raise NotImplementedError("input_type {} is not handled
yet".format(input_type))
+
+ @staticmethod
+ def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var:
+ tensor = tensor.detach().cpu()
+ shape = tensor.data.shape
+ dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
+ return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape,
dtype))
+
+ @staticmethod
+ def shape_of(tensor):
+ """Get the shape of a tensor."""
+ import torch # type: ignore
+
+ if isinstance(tensor, relax.Expr):
+ if not isinstance(tensor.struct_info, relax.TensorStructInfo):
+ raise TypeError("The input Expr of shape_of should be a
Tensor")
+ return tensor.struct_info.shape
+ elif isinstance(tensor, torch.Tensor):
+ return tensor.shape
+ raise ValueError("Unsupported type: {}".format(type(tensor)))
+
+ def retrieve_args(self, node):
+ return self._retrieve_args(node.args)
+
+ def _retrieve_args(self, node):
+ from torch import fx
+
+ if isinstance(node, fx.node.Node):
+ return self.env[node]
+ elif isinstance(node, tuple):
+ return tuple(self._retrieve_args(x) for x in node)
+ elif isinstance(node, list):
+ return [self._retrieve_args(x) for x in node]
+ elif isinstance(node, dict):
+ return {self._retrieve_args(k): self._retrieve_args(v) for k, v in
node.items()}
+ else:
+ return node
+
+ @staticmethod
+ def _promote_binary_op_args(lhs, rhs):
+ if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
+ return lhs, rhs
+ elif isinstance(lhs, relax.Expr):
+ assert isinstance(lhs.struct_info, relax.TensorStructInfo)
+ return lhs, relax.const(rhs, lhs.struct_info.dtype)
+ elif isinstance(rhs, relax.Expr):
+ assert isinstance(rhs.struct_info, relax.TensorStructInfo)
+ return relax.const(lhs, rhs.struct_info.dtype), rhs
+ else:
+ assert False
+
+ def _call_binary_op(self, op, lhs, rhs):
+ lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs)
+ return self.block_builder.emit(op(lhs, rhs))
+
+ ########## Arithmetic ##########
+
+ def _cos(self, node: fx.node.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.cos(self.env[node.args[0]]))
+
+ def _sin(self, node: fx.node.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.sin(self.env[node.args[0]]))
+
+ def _sqrt(self, node: fx.node.Node) -> relax.Expr:
+ arg = self.env[node.args[0]]
+ if isinstance(arg, (int, float)):
+ arg = relax.const(arg, "float32")
+ return self.block_builder.emit(relax.op.sqrt(arg))
+
+ def _add(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.add, lhs, rhs)
+ return lhs + rhs
+
+ def _floordiv(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
+ return lhs // rhs
+
+ def _mul(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.multiply, lhs, rhs)
+ return lhs * rhs
+
+ def _sub(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.subtract, lhs, rhs)
+ return lhs - rhs
+
+ def _truediv(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
+ return self._call_binary_op(relax.op.divide, lhs, rhs)
+ return lhs / rhs
+
+ def _clamp(self, node: fx.node.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ a_min = node.kwargs["min"]
+ a_max = node.kwargs["max"]
+ if not isinstance(a_min, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant min value for torch.clamp/clip, "
+ f"but got {a_min} with type {type(a_min)}"
+ )
+ if not isinstance(a_max, (int, float)):
+ raise ValueError(
+ f"TVM only supports constant max value for torch.clamp/clip, "
+ f"but got {a_max} with type {type(a_max)}"
+ )
+ return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+
+ ########## Compare ##########
+
+ def _lt(self, node: fx.node.Node) -> relax.Expr:
+ lhs, rhs = self.retrieve_args(node)
+ return self._call_binary_op(relax.op.less, lhs, rhs)
+
+ ########## Creation ##########
+
+ def _tril(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ k = node.args[1] if len(node.args) > 1 else 0
+ assert isinstance(k, int)
+ return self.block_builder.emit(relax.op.create.tril(x, k))
+
+ def _new_ones(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ self_var = args[0]
+ size = args[1:]
+ if not isinstance(size, (list, tuple)):
+ size = (size,)
+ size = relax.ShapeExpr(size)
+ return self.block_builder.emit(
+ relax.op.full(
+ size,
+ relax.const(1, self_var.struct_info.dtype),
+ self_var.struct_info.dtype,
+ )
+ )
+
+ ########## Statistical ##########
+
+ def _sum(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ if len(args) == 1:
+ return self.block_builder.emit(relax.op.sum(args[0]))
+ return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+
+ ########## DataType ##########
+
+ def _float(self, node: fx.node.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float32"))
+
+ def _half(self, node: fx.node.Node) -> relax.Var:
+ return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float16"))
+
+ def _type(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ return self.block_builder.emit(relax.op.astype(args[0], args[1]))
+
+ ########## Linear Algebra ##########
+
+ def _matmul_impl(self, a: relax.Expr, b: relax.Expr):
+ return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b,
out_dtype="float32"))
+
+ def _matmul(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ res = self._matmul_impl(
+ args[0],
+ args[1],
+ )
+ return res
+
+ def _addmm(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ y = self.env[node.args[1]]
+ z = self.env[node.args[2]]
+ matmul = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z,
out_dtype="float32"))
+ return self.block_builder.emit(relax.op.add(x, matmul))
+
+ ########## Manipulation ##########
+
+ def _cat(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ return self.block_builder.emit(relax.op.concat(args[0],
axis=node.kwargs["dim"]))
+
+ def _expand(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ return self.block_builder.emit(relax.op.broadcast_to(args[0],
args[1:]))
+
+ def _flatten(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ if node.target in self.named_modules:
+ module = self.named_modules[node.target]
+ start_dim = module.start_dim
+ end_dim = module.end_dim
+ else:
+ start_dim = node.args[1] if len(node.args) >= 2 else 0
+ end_dim = node.args[2] if len(node.args) == 3 else -1
+ shape = self.shape_of(x)
+ start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim
+ end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim
+ flattened = reduce(lambda x, y: x * y, [shape[i] for i in
range(start_dim, end_dim + 1)])
+ new_shape = (
+ [shape[i] for i in range(0, start_dim)]
+ + [flattened]
+ + [shape[i] for i in range(end_dim + 1, len(shape))]
+ )
+ return self.block_builder.emit(relax.op.reshape(x, new_shape))
+
+ def _permute(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ return self.block_builder.emit(relax.op.permute_dims(args[0],
args[1:]))
+
+ def _reshape(self, node: fx.node.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ if isinstance(args[1], (torch.Size, tuple, list)):
+ return self.block_builder.emit(relax.op.reshape(args[0],
tuple(args[1])))
+ return self.block_builder.emit(relax.op.reshape(args[0], args[1:]))
+
+ def _split(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ split_size = node.args[1]
+ if "dim" in node.kwargs:
+ dim = node.kwargs["dim"]
+ else:
+ dim = 0
+ n_section = (self.shape_of(x)[dim].value + split_size - 1) //
split_size
+ return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+ def _transpose(self, node: fx.node.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ full_idx = list(range(len(self.shape_of(args[0]))))
+ full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]],
full_idx[args[1]]
+ return self.block_builder.emit(relax.op.permute_dims(args[0],
full_idx))
+
+ ########## Neural Network ##########
+
+ def _linear(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+ bias = None if module.bias is None else self.params[module.bias]
+ return self.block_builder.emit(relax.op.linear(x, weight, bias,
"float32"))
+
+ def _conv2d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+
+ conv2d = self.block_builder.emit(
+ relax.op.nn.conv2d(
+ x,
+ weight,
+ strides=module.stride,
+ padding=module.padding,
+ dilation=module.dilation,
+ groups=module.groups,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_dtype="float32",
+ )
+ )
+
+ if module.bias is None:
+ return conv2d
+
+ bias = self.params[module.bias]
+ assert len(self.shape_of(bias)) == 1
+ bias = relax.op.reshape(bias, (1, -1, 1, 1))
+
+ return self.block_builder.emit(relax.op.add(conv2d, bias))
+
+ def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ if node.target in self.named_modules:
+ module = self.named_modules[node.target]
+ kernel = module.kernel_size
+ stride = module.stride
+ padding = module.padding
+ dilation = module.dilation
+ ceil_mode = module.ceil_mode
+ else:
+ nargs = len(node.args)
+ kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"]
+ stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
+ padding = node.args[3] if nargs > 3 else node.kwargs["padding"]
+ dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"]
+ ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"]
+
+ stride = kernel if stride is None else stride
+
+ return self.block_builder.emit(
+ relax.op.nn.max_pool2d(
+ x,
+ pool_size=kernel,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ layout="NCHW",
+ ceil_mode=ceil_mode,
+ )
+ )
+
+ def _adaptive_avg_pool2d(self, is_module: bool) -> Callable:
+ from torch import fx
+
+ def _impl(node: fx.node.Node) -> relax.Var:
+ if is_module:
+ module = self.named_modules[node.target]
+ x = self.env[node.args[0]]
+ output_size = module.output_size
+ else:
+ x = self.env[node.args[0]]
+ output_size = node.args[1]
+ return self.block_builder.emit(
+ relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW")
+ )
+
+ return _impl
+
+ def _softmax(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ if node.target in self.named_modules:
+ module = self.named_modules[node.target]
+ dim = module.dim
+ else:
+ nargs = len(node.args)
+ dim = node.args[1] if nargs > 1 else node.kwargs["dim"]
+ assert dim is not None
+ return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+
+ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+ bias = self.params[module.bias]
+ dtype = self._convert_data_type(str(module.running_mean.dtype))
+ running_mean = relax.const(module.running_mean.cpu().detach().numpy(),
dtype)
+ running_var = relax.const(module.running_var.cpu().detach().numpy(),
dtype)
+ eps = module.eps
+
+ res_tuple = self.block_builder.emit(
+ relax.op.nn.batch_norm(
+ x,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ axis=1,
+ epsilon=eps,
+ )
+ )
+
+ return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0))
+
+ def _layer_norm(self, node: fx.node.Node) -> relax.Var:
+ import torch # type: ignore
+
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+
+ if module.elementwise_affine:
+ gamma = self.params[module.weight]
+ beta = self.params[module.bias]
+ else:
+ gamma = relax.const(torch.ones_like(module.normalized_shape),
x.checked_type)
+ beta = relax.const(torch.zeros_like(module.normalized_shape),
x.checked_type)
+ dim_num = len(module.normalized_shape)
+ axes = list(range(-dim_num, 0))
+
+ return self.block_builder.emit(
+ relax.op.nn.layer_norm(
+ x,
+ gamma,
+ beta,
+ axes=axes,
+ epsilon=module.eps,
+ )
+ )
+
+ def _group_norm(self, node: fx.node.Node) -> relax.Var:
+ # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05,
+ # affine=True, device=None, dtype=None)
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ num_groups = module.num_groups
+ num_channels = module.num_channels
+ eps = module.eps
+ affine = module.affine
+
+ shape = self.shape_of(x)
+ assert len(shape) == 4
+ N, C, H, W = shape[0], shape[1], shape[2], shape[3]
+ assert C == num_channels
+ assert C % num_groups == 0
+ grouped_x = self.block_builder.emit(
+ relax.op.reshape(x, [N, num_groups, C // num_groups, H, W])
+ )
+ mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4],
keepdims=True))
+ sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x))
+ square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x))
+ sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3,
4], keepdims=True))
+ var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C //
num_groups * H * W).value)
+ var_x_eps = self._call_binary_op(relax.op.add, var_x, eps)
+ std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps))
+ norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x))
+
+ if affine:
+ weight = self.params[module.weight]
+ bias = self.params[module.bias]
+ weight_reshape = self.block_builder.emit(
+ relax.op.reshape(weight, (1, num_groups, C // num_groups, 1,
1))
+ )
+ bias_reshape = self.block_builder.emit(
+ relax.op.reshape(bias, (1, num_groups, C // num_groups, 1, 1))
+ )
+ norm_x = self.block_builder.emit(relax.op.multiply(norm_x,
weight_reshape))
+ norm_x = self.block_builder.emit(relax.op.add(norm_x,
bias_reshape))
+ return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W)))
+
+ def _embedding(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ module = self.named_modules[node.target]
+ weight = self.params[module.weight]
+ x = self.block_builder.emit(relax.op.astype(x, "int32"))
+ return self.block_builder.emit(relax.op.take(weight, x, axis=0))
+
+ def _interpolate(self, node: fx.node.Node) -> relax.Var:
+ # torch.nn.functional.interpolate(
+ # input, size=None, scale_factor=None, mode='nearest',
align_corners=None,
+ # recompute_scale_factor=None, antialias=False)
+ # (TODO) this is a temporary implementation for interpolate that only
considers NCHW layout
+ # it basically replicates the implementation in
tvm.relay.frontend.pytorch
+ data = self.env[node.args[0]]
+ size = node.kwargs["size"]
+ scale_factor = node.kwargs["scale_factor"]
+ method = node.kwargs["mode"]
+ align_corners = node.kwargs["align_corners"]
+ recompute_scale_factor = node.kwargs["recompute_scale_factor"]
+ antialias = node.kwargs["antialias"]
+
+ assert recompute_scale_factor is None
+ assert antialias is False
+
+ if size is None:
+ shape = self.shape_of(data)
+ assert isinstance(shape, relax.ShapeExpr)
+ size = tuple(int(shape[i].value * scale_factor) for i in range(2,
len(shape)))
+
+ if method.startswith("nearest"):
+ method = "nearest_neighbor"
+ elif method[0:2] == "bi":
+ method = method[2:]
+
+ if method == "nearest_neighbor":
+ coord_trans = "asymmetric"
+ elif align_corners:
+ coord_trans = "align_corners"
+ else:
+ coord_trans = "half_pixel"
+
+ return self.block_builder.emit(
+ relax.op.image.resize2d(
+ data, size, layout="NCHW", method=method,
coordinate_transformation_mode=coord_trans
+ )
+ )
+
+ ########## Others ##########
+
+ def _size(self, node: fx.node.Node) -> relax.Expr:
+ x = self.env[node.args[0]]
+ shape = self.shape_of(x)
+ if len(node.args) == 1:
+ assert isinstance(shape, relax.ShapeExpr)
+ return shape
+ assert len(node.args) == 2
+ idx = node.args[1]
+ return self.shape_of(x)[idx].value
+
+ def _getattr(self, node: fx.node.Node) -> relax.Var:
+ if isinstance(self.env[node.args[0]], relax.Expr):
+ if node.args[1] == "dtype":
+ return self.env[node.args[0]].struct_info.dtype
+ elif node.args[1] == "shape":
+ return self.shape_of(self.env[node.args[0]])
+ return getattr(self.env[node.args[0]], node.args[1])
+
+ def _getitem(self, node: fx.node.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
+ return x[node.args[1]]
+ elif isinstance(x, relax.Var):
+ if isinstance(x.struct_info, relax.TupleStructInfo):
+ return self.block_builder.emit(relax.TupleGetItem(x,
node.args[1]))
+
+ assert isinstance(x.struct_info, relax.TensorStructInfo)
+ begin = []
+ end = []
+ stride = []
+ axes = []
+ expand_dim = []
+ i = 0
+ shape = self.shape_of(x)
+ for index in node.args[1]:
+ if isinstance(index, int):
+ begin.append(index)
+ end.append(index + 1)
+ stride.append(1)
+ axes.append(i)
+ i = i + 1
+ elif isinstance(index, slice):
+ begin.append(0 if index.start is None else index.start)
+ end.append(shape[i] if index.stop is None else index.stop)
+ stride.append(1 if index.step is None else index.step)
+ axes.append(i)
+ i = i + 1
+ elif index is None:
+ expand_dim.append(i)
+ i = i + 1
+ else:
+ raise ValueError("Unsupported index type: " +
str(type(index)))
+ while i < len(shape):
+ begin.append(0)
+ end.append(shape[i])
+ axes.append(i)
+ i = i + 1
+ sliced = self.block_builder.emit(relax.op.strided_slice(x, axes,
begin, end, stride))
+ sliced_shape = list(self.shape_of(sliced))
+ for i in expand_dim:
+ sliced_shape.insert(i, 1)
+ return self.block_builder.emit(relax.op.reshape(sliced,
sliced_shape))
+ else:
+ assert False
+
    def create_convert_map(self):
        """Populate ``self.convert_map``, the dispatch table used by ``from_fx``.

        Keys are either ``nn.Module`` subclasses (dispatched on ``type(module)``
        for ``call_module`` nodes) or strings (dispatched on the de-suffixed node
        name for ``call_function`` nodes, or on ``node.target`` for
        ``call_method`` nodes). Each value maps one FX node to a Relax var.
        """
        from torch import nn
        from torch import fx

        self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = {
            # call_module
            nn.Linear: self._linear,
            nn.Conv2d: self._conv2d,
            nn.MaxPool2d: self._max_pool2d,
            # adaptive_avg_pool2d is shared between module and functional forms;
            # the factory closes over is_module to pick the argument convention.
            nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
            nn.Softmax: self._softmax,
            nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
            # ReLU6 is expressed as clip(x, 0, 6); Relax has no dedicated relu6 op here.
            nn.ReLU6: lambda node: self.block_builder.emit(
                relax.op.clip(self.env[node.args[0]], 0, 6)
            ),
            nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
            nn.Flatten: self._flatten,
            nn.BatchNorm2d: self._batch_norm_2d,
            nn.LayerNorm: self._layer_norm,
            nn.GroupNorm: self._group_norm,
            # Dropout is an identity mapping at inference: forward the input binding.
            nn.Dropout: lambda node: self.env[node.args[0]],
            nn.modules.sparse.Embedding: self._embedding,
            # call_function and call_method
            "cos": self._cos,
            "sin": self._sin,
            "add": self._add,
            "floordiv": self._floordiv,
            "mul": self._mul,
            "sub": self._sub,
            "sqrt": self._sqrt,
            "lt": self._lt,
            "truediv": self._truediv,
            "new_ones": self._new_ones,
            "tril": self._tril,
            "sum": self._sum,
            "float": self._float,
            "half": self._half,
            "type": self._type,
            "matmul": self._matmul,
            "addmm": self._addmm,
            "cat": self._cat,
            "expand": self._expand,
            "flatten": self._flatten,
            "permute": self._permute,
            "reshape": self._reshape,
            "split": self._split,
            "transpose": self._transpose,
            "unsqueeze": lambda node: self.block_builder.emit(
                relax.op.expand_dims(self.env[node.args[0]], node.args[1])
            ),
            # view shares reshape's semantics for translation purposes.
            "view": self._reshape,
            "softmax": self._softmax,
            "clamp": self._clamp,
            "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])),
            "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
            "interpolate": self._interpolate,
            "size": self._size,
            "getattr": self._getattr,
            "getitem": self._getitem,
            # contiguous only affects memory layout in PyTorch; a no-op in Relax.
            "contiguous": lambda node: self.env[node.args[0]],
            "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False),
        }
+
    def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
        """Convert a PyTorch FX GraphModule to a Relax program.

        Builds a single Relax function named "main" whose parameters are the
        graph placeholders (in order, typed from `input_info`), translating
        each FX node through `self.convert_map`.

        Parameters
        ----------
        model : fx.GraphModule
            The traced PyTorch module to translate.

        input_info : List[Tuple[Tuple[int], str]]
            (shape, dtype) pairs, one per placeholder, in placeholder order.

        Returns
        -------
        tvm.IRModule
            The module produced by the block builder.
        """
        from torch import fx

        self.named_modules = dict(model.named_modules())

        graph: fx.Graph = model.graph
        # Create input variables.
        inputs = list()
        for idx, (shape, dtype) in enumerate(input_info):
            inputs.append(
                relax.Var(
                    f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype))
                )
            )

        # Initialize the block builder with a function and a dataflow block.
        self.block_builder = relax.BlockBuilder()
        # inputs.copy(): the list is consumed (pop) below while the function
        # signature keeps the original ordering.
        with self.block_builder.function(name="main", params=inputs.copy()):
            output = None
            with self.block_builder.dataflow():
                # Translate model parameters. Only float params are supported;
                # each becomes an embedded constant keyed by the torch Parameter.
                for _, param in model.named_parameters():
                    shape = param.data.shape
                    dtype = self._convert_data_type(str(param.data.dtype))
                    if dtype in ("float32", "float16"):
                        self.params[param] = relax.const(
                            param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype)
                        )
                    else:
                        raise ValueError("Unsupported data type for model parameters: %s" % dtype)
                # Translate the model.
                for node in graph.nodes:
                    if node.op == "placeholder":
                        assert len(inputs) > 0, "Provided inputs is less than actual inputs"
                        self.env[node] = inputs.pop(0)
                    elif node.op == "output":
                        args = self.retrieve_args(node)
                        output = self.block_builder.emit_output(args[0])
                        break
                    elif node.op == "get_attr":
                        self.env[node] = TorchFXImporter._fetch_attr(model, node.target)
                    elif node.op == "call_module":
                        module = self.named_modules[node.target]
                        assert (
                            type(module) in self.convert_map
                        ), f"Unsupported module type {type(module)}"
                        self.env[node] = self.convert_map[type(module)](node)
                    elif node.op == "call_function":
                        # NOTE(review): node.name is FX's generated name (e.g.
                        # "add_1"); rstrip removes the dedup suffix to recover the
                        # convert_map key. This assumes no supported op name
                        # itself ends in a digit/underscore — confirm if ops like
                        # relu6 are ever dispatched via call_function.
                        func_name = node.name.rstrip("0123456789_")
                        assert (
                            func_name in self.convert_map
                        ), f"Unsupported function type {func_name}"
                        self.env[node] = self.convert_map[func_name](node)
                    elif node.op == "call_method":
                        assert (
                            node.target in self.convert_map
                        ), f"Unsupported function target {node.target}"
                        self.env[node] = self.convert_map[node.target](node)
                    else:
                        raise ValueError(f"Unsupported op {node.op}")
                assert output is not None
                self.block_builder.emit_func_output(output)

        return self.block_builder.get()
+
+
def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule:
    """Convert a PyTorch FX GraphModule to a Relax program

    Parameters
    ----------
    model : fx.GraphModule
        The PyTorch FX GraphModule to convert.

    input_info : List[Tuple[Tuple[int], str]]
        A list of shapes and data types of input tensors.

    Returns
    -------
    module : tvm.IRModule
        The converted Relax program.

    Examples
    --------
    Users can use the FX tracer or dynamo.export() to extract
    a fx.GraphModule from a PyTorch model. The following codes show
    how to convert a PyTorch model to a Relax program.

    .. code-block:: python

        # Import the importer.
        import numpy as np
        import torch
        from torch import fx
        from torch import _dynamo as dynamo
        from tvm.relax.frontend.torch import from_fx

        # Define the module
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True)

            def forward(self, input):
                return self.linear(input)

        # Instantiate the model and create the input info dict.
        torch_model = MyModule()
        input_info = [((128, 10), "float32")]
        input_tensors = [
            torch.as_tensor(np.random.randn(*shape).astype(dtype))
            for shape, dtype in input_info
        ]

        # Use FX tracer to trace the PyTorch model.
        graph_module = fx.symbolic_trace(torch_model)

        # Use the dynamo.export() to export the PyTorch model to FX.
        try:
            graph_module = dynamo.export(torch_model, *input_tensors)
        except Exception as err:
            raise RuntimeError("Failed to export the PyTorch model to FX.") from err

        # Use the importer to import the PyTorch model to Relax.
        mod: tvm.IRModule = from_fx(graph_module, input_info)

        # Print out the imported model.
        print(mod.script())

    Notes
    -----
    For a given PyTorch model, to lookup the names of the model inputs in
    FX, one can use

    .. code-block:: python

        fx.symbolic_trace(model).graph.print_tabular()

    to print out the tabular representation of the PyTorch module, and then
    check the placeholder rows in the beginning of the tabular.
    """
    # Thin convenience wrapper: delegate to a fresh importer instance.
    return TorchFXImporter().from_fx(model, input_info)
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
new file mode 100644
index 0000000000..a5da15b76d
--- /dev/null
+++ b/python/tvm/relax/pipeline.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pre-defined pipelines.
+
+Relax enables flexible pipeline optimizations before minimum build.
+This namespace offers a pre-defined collection that can be used
+as it is or serves as a basis to do further composition.
+"""
+# pylint: disable=unused-argument
+import tvm
+from tvm import meta_schedule as ms
+from . import transform
+
+
@tvm.transform.module_pass(opt_level=0)
def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
    """Baseline pipeline: legalize, fold, fuse, then apply pre-tuned logs if any.

    Parameters
    ----------
    mod : tvm.ir.IRModule
        Input IRModule.

    ctx : tvm.transform.PassContext
        The pass context

    Returns
    -------
    mod: tvm.ir.IRModule
        The result transformed module.
    """
    lowering_passes = [
        transform.LegalizeOps(),
        transform.AnnotateTIROpPattern(),
        transform.FoldConstant(),
        transform.FuseOps(),
        transform.FuseTIR(),
    ]
    mod = tvm.transform.Sequential(lowering_passes)(mod)
    # Only consult a tuning database when one is active in the current scope.
    if ms.Database.current():
        mod = transform.MetaScheduleApplyDatabase()(mod)
    return mod
+
+
# global map of pre-built pipelines
PIPELINE_MAP = {"zero": zero_pipeline}


def get_pipeline(name: str = "zero") -> tvm.transform.Pass:
    """Get a pre-built pipeline by name.

    Parameters
    ----------
    name : Optional[str]
        Name of the pipeline

    Returns
    -------
    pipeline: tvm.transform.Pass
        The transformation pipeline.

    Raises
    ------
    ValueError
        If `name` does not match any registered pipeline.
    """
    if name not in PIPELINE_MAP:
        # Single f-string: the previous implicit concatenation of two f-strings
        # produced "...pipeline zero,candidates are..." with no separating space.
        raise ValueError(
            f"Unknown pre-built pipeline {name}, candidates are {list(PIPELINE_MAP.keys())}"
        )
    return PIPELINE_MAP[name]
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
new file mode 100644
index 0000000000..370df2103d
--- /dev/null
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("torch._dynamo")
+
+
+import tvm
+from tvm import relax, meta_schedule as ms, tir
+import tvm.testing
+import torch
+import torch._dynamo as dynamo
+from tvm.relax.frontend.torch import relax_dynamo
+from tvm.script.parser import relax as R, tir as T
+
+
def test_relax_dynamo():
    """End-to-end check of the Relax Dynamo backend with a pre-committed tuning log."""

    class Input1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    model = Input1()
    ### construct the database
    # Hand-written TIR equivalent of Linear(100, 10) + ReLU; used as the
    # workload key for the tuning record committed below.
    @tvm.script.ir_module
    class Input1_ir:
        @T.prim_func
        def main(
            inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"],
            param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"],
            param_1: T.Buffer[T.int64(10), "float32"],
            compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"],
        ):
            # function attr dict
            T.func_attr({"tir.noalias": True, "global_symbol": "main"})
            # body
            # with T.block("root")
            matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32")
            T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32")
            for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)):
                with T.block("matmul"):
                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                    T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1])
                    T.writes(matmul[v_i0, v_i1])
                    with T.init():
                        matmul[v_i0, v_i1] = T.float32(0)
                    matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1]
            for i0, i1 in T.grid(T.int64(10), T.int64(10)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(T_add[v_i0, v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], T.float32(0))

    db = ms.Database.create("memory")
    workload = db.commit_workload(Input1_ir)

    # Record a fixed schedule trace (tiling + inline + annotations) so that
    # MetaScheduleApplyDatabase has a tuned entry to pick up during compile.
    sch = tir.Schedule(Input1_ir, debug_mask="all")
    b0 = sch.get_block(name="matmul", func_name="main")
    b1 = sch.get_block(name="T_add", func_name="main")
    b2 = sch.get_block(name="root", func_name="main")
    sch.compute_inline(block=b1)
    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")
    l3, l4, l5 = sch.get_loops(block=b0)
    v6, v7, v8, v9 = sch.sample_perfect_tile(
        loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1]
    )
    l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], preserve_unit_iters=True)
    v14, v15, v16, v17 = sch.sample_perfect_tile(
        loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1]
    )
    l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], preserve_unit_iters=True)
    v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, decision=[100, 1])
    l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True)
    sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21)
    (b26,) = sch.get_consumers(block=b0)
    sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, index=-1)
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=96)
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=64)
    v27 = sch.sample_categorical(
        candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0
    )
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v27)

    # run_secs=[0.0] marks the record as "measured"; the value itself is unused.
    tuning_record = ms.database.TuningRecord(sch.trace, workload, run_secs=[0.0])
    db.commit_tuning_record(tuning_record)
    ### Optimize the model with tuned-log
    with db:
        opt_model = torch.compile(model, backend=relax_dynamo())
    inp = torch.randn(10, 100)
    # Compiled and eager models must agree numerically.
    tvm.testing.assert_allclose(
        opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5
    )
+
+
def test_subgraph_capture():
    """Check dynamo_capture_subgraphs on a plain module and on a function with
    data-dependent control flow (which forces a graph break into two subgraphs)."""
    import torch
    from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs

    class Input1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x):
            return torch.nn.functional.relu(self.lin(x))

    @tvm.script.ir_module
    class Expected1:
        @R.function
        def subgraph_0(
            inp_0: R.Tensor((10, 100), dtype="float32"),
            w0: R.Tensor((10, 100), dtype="float32"),
            w1: R.Tensor((10,), dtype="float32"),
        ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None)
                lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32")
                lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1)
                lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2)
                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,)
                R.output(gv)
            return gv

    model = Input1()
    mod = dynamo_capture_subgraphs(model, torch.randn(10, 100))
    # Bind the live weights so the captured module can be compared structurally.
    binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()}
    binding = {k: tvm.nd.array(v) for k, v in binding.items()}
    expected = relax.transform.BindParams("subgraph_0", binding)(Expected1)
    tvm.ir.assert_structural_equal(mod, expected)

    # The `if torch.sum(b) < 1` branch is resolved on the host, so Dynamo
    # breaks the graph: everything before the branch lands in subgraph_0,
    # the multiply after it in subgraph_1.
    def Input2(a, b):
        x = a / (torch.sin(a) + 1)
        if torch.sum(b) < 1:
            b = b * -1
        return x * b

    @tvm.script.ir_module
    class Expected2:
        @R.function
        def subgraph_0(
            inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), dtype="float32")
        ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0)
                lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, "float32"))
                lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1)
                lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, keepdims=False)
                lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, "float32"))
                gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")) = (
                    lv2,
                    lv4,
                )
                R.output(gv)
            return gv

        @R.function
        def subgraph_1(
            inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32")
        ) -> R.Tuple(R.Tensor((10,), dtype="float32")):
            # block 0
            with R.dataflow():
                lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01)
                gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,)
                R.output(gv1)
            return gv1

    # torch.ones(10) makes the branch condition False (sum == 10), a fixed path.
    mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
    tvm.ir.assert_structural_equal(mod, Expected2)
+
+
# Allow running this test file directly; tvm.testing.main dispatches to pytest.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
new file mode 100644
index 0000000000..9b35d34bd3
--- /dev/null
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -0,0 +1,1729 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+import tvm
+from tvm import relax
+import tvm.testing
+from tvm.script.parser import relax as R, tir as T
+
+
def verify_model(torch_model, input_info, binding, expected):
    """Trace `torch_model` with FX, import it via from_fx, bind `binding`
    into `expected`'s "main", and require structural equality."""
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    traced = fx.symbolic_trace(torch_model)
    imported_mod = from_fx(traced, input_info)
    bound_params = {name: tvm.nd.array(value) for name, value in binding.items()}
    reference = relax.transform.BindParams("main", bound_params)(expected)
    tvm.ir.assert_structural_equal(imported_mod, reference)
+
+
@tvm.testing.requires_gpu
def test_conv():
    """Translate nn.Conv2d (with and without bias) to R.nn.conv2d."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
            w2: R.Tensor((6,), dtype="float32"),
        ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                # Bias is reshaped to (1, C, 1, 1) and added after the conv.
                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
                lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
                R.output(gv)
            return gv

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, input):
            return self.conv(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((6, 3, 7, 7), dtype="float32"),
        ) -> R.Tensor((1, 6, 4, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d(
                    input_1,
                    w1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    data_layout="NCHW",
                    kernel_layout="OIHW",
                    out_layout="NCHW",
                    out_dtype="float32",
                )
                gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1
                R.output(gv)
            return gv

    input_info = [([1, 3, 10, 10], "float32")]

    model = Conv2D1()
    binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()}
    verify_model(model, input_info, binding, expected1)

    model = Conv2D2()
    binding = {"w1": model.conv.weight.numpy()}
    verify_model(model, input_info, binding, expected2)
+
+
@tvm.testing.requires_gpu
def test_linear():
    """Translate nn.Linear (bias / no-bias) and torch.matmul to Relax."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    # nn.Linear
    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=True)

        def forward(self, input):
            return self.linear(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((7, 10), dtype="float32"),
            w2: R.Tensor((1, 7), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
            # block 0
            with R.dataflow():
                # Linear is weight-transpose followed by matmul (+ bias add).
                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
                    input_1, lv, out_dtype="float32"
                )
                lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
                gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2
                R.output(gv)
            return gv

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=False)

        def forward(self, input):
            return self.linear(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((7, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 7), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
                    input_1, lv, out_dtype="float32"
                )
                gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1
                R.output(gv)
            return gv

    input_info = [([1, 3, 10, 10], "float32")]

    model = Dense1()
    binding = {"w1": model.linear.weight.numpy(), "w2": model.linear.bias.numpy()}
    verify_model(model, input_info, binding, expected1)

    model = Dense2()
    binding = {"w1": model.linear.weight.numpy()}
    verify_model(model, input_info, binding, expected2)

    # matmul
    class MatMul1(Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y):
            return torch.matmul(x, y)

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input_1: R.Tensor((10, 10), dtype="float32"),
            input_2: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(
                    input_1, input_2, out_dtype="float32"
                )
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(
        MatMul1(),
        [([10, 10], "float32"), ([10, 10], "float32")],
        {},
        expected3,
    )
+
+
@tvm.testing.requires_gpu
def test_relu():
    """nn.ReLU module and functional relu both lower to R.nn.relu."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)

    class ReLU0(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, input):
            return self.relu(input)

    class ReLU1(Module):
        def forward(self, input):
            return torch.nn.functional.relu(input)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(
            input_1: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1)
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    input_info = [([10, 10], "float32")]
    # Both module and functional forms must produce the same IR.
    verify_model(ReLU0(), input_info, {}, expected)
    verify_model(ReLU1(), input_info, {}, expected)
+
+
@tvm.testing.requires_gpu
def test_relu6():
    """nn.ReLU6 lowers to R.clip(x, 0, 6)."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, input):
            return self.relu6(input)

    @tvm.script.ir_module
    class expected:
        @R.function
        def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6)
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    input_info = [([10, 10], "float32")]
    verify_model(ReLU6(), input_info, {}, expected)
+
+
@tvm.testing.requires_gpu
def test_maxpool2d():
    """nn.MaxPool2d with default, dilated, and padded/strided configurations."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[1, 1],
                    strides=[1, 1],
                    dilation=[1, 1],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 4, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                # With no explicit stride, PyTorch uses stride == kernel_size.
                lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[2, 2],
                    strides=[2, 2],
                    dilation=[2, 3],
                    padding=[0, 0, 0, 0],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv
                R.output(gv)
            return gv

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, input):
            return self.pool(input)

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 6, 6), dtype="float32"):
            # block 0
            with R.dataflow():
                # Scalar padding 2 expands to [2, 2, 2, 2] (top/left/bottom/right).
                lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d(
                    input_1,
                    pool_size=[4, 4],
                    strides=[2, 2],
                    dilation=[1, 1],
                    padding=[2, 2, 2, 2],
                    layout="NCHW",
                    out_layout="NCHW",
                )
                gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(MaxPool2d(), input_info, {}, expected1)
    verify_model(MaxPool2d2(), input_info, {}, expected2)
    verify_model(MaxPool2d3(), input_info, {}, expected3)
+
+
@tvm.testing.requires_gpu
def test_adaptive_avgpool2d():
    """Module and functional adaptive_avg_pool2d both map to R.nn.adaptive_avg_pool2d."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, input):
            return self.pool(input)

    class AdaptiveAvgPool2d1(Module):
        def forward(self, input):
            return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10])

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d(
                    input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW"
                )
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(AdaptiveAvgPool2d0(), input_info, {}, expected1)
    verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_flatten():
    """nn.Flatten(2, -1) lowers to a static R.reshape collapsing trailing dims."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, input):
            return self.f(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 100), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100))
                gv: R.Tensor((1, 3, 100), dtype="float32") = lv
                R.output(gv)
            return gv

    # call_module
    verify_model(Flatten(), input_info, {}, expected1)
    # call_method
    verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_batchnorm2d():
    """nn.BatchNorm2d lowers to R.nn.batch_norm; only result[0] is used."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, input):
            return self.bn(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
            w3: R.Tensor((3,), dtype="float32"),
            w4: R.Tensor((3,), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                # batch_norm returns (normalized, new_mean, new_var); the test
                # model only consumes the normalized tensor (lv[0]).
                lv: R.Tuple(
                    R.Tensor((1, 3, 10, 10), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                    R.Tensor((3,), dtype="float32"),
                ) = R.nn.batch_norm(
                    input_1,
                    w1,
                    w2,
                    w3,
                    w4,
                    axis=1,
                    epsilon=1e-05,
                    center=True,
                    scale=True,
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1
                R.output(gv)
            return gv

    model = BatchNorm2d()
    binding = {
        "w1": model.bn.weight.numpy(),
        "w2": model.bn.bias.numpy(),
        "w3": model.bn.running_mean.numpy(),
        "w4": model.bn.running_var.numpy(),
    }
    verify_model(BatchNorm2d(), input_info, binding, expected1)
+
+
@tvm.testing.requires_gpu
def test_embedding():
    """nn.Embedding lowers to an int32 cast plus R.take on the weight table."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([4], "int64")]

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, input):
            return self.embedding(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32")
        ) -> R.Tensor((4, 3), dtype="float32"):
            # block 0
            with R.dataflow():
                # Indices are narrowed from int64 to int32 before the gather.
                lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32")
                lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0)
                gv: R.Tensor((4, 3), dtype="float32") = lv1
                R.output(gv)
            return gv

    model = Embedding()
    binding = {"w1": model.embedding.weight.numpy()}
    verify_model(model, input_info, binding, expected1)
+
+
@tvm.testing.requires_gpu
def test_dropout():
    """nn.Dropout is translated as identity (inference semantics)."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Dropout(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, input):
            return self.dropout(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                # Dropout produces no ops: the output binds directly to the input.
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1
                R.output(gv)
            return gv

    verify_model(Dropout(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_layernorm():
    """nn.LayerNorm over the last two axes lowers to R.nn.layer_norm."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.ln = torch.nn.LayerNorm((10, 10))

        def forward(self, input):
            return self.ln(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((10, 10), dtype="float32"),
            w2: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                # normalized_shape (10, 10) maps to the trailing axes [-2, -1].
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm(
                    input_1,
                    w1,
                    w2,
                    axes=[-2, -1],
                    epsilon=1e-05,
                    center=True,
                    scale=True,
                )
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    model = LayerNorm()
    binding = {
        "w1": model.ln.weight.numpy(),
        "w2": model.ln.bias.numpy(),
    }
    verify_model(LayerNorm(), input_info, binding, expected1)
+
+
@tvm.testing.requires_gpu
def test_silu():
    """nn.SiLU lowers to R.nn.silu."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, input):
            return self.silu(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(SiLU(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_groupnorm():
    """Translate torch.nn.GroupNorm via the FX frontend.

    The frontend decomposes group norm into reshape / mean / variance /
    normalize / affine steps, so the expected module spells out that chain.
    """
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.gn = torch.nn.GroupNorm(3, 3)

        def forward(self, input):
            return self.gn(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            w1: R.Tensor((3,), dtype="float32"),
            w2: R.Tensor((3,), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape(
                    input_1, (1, 3, 1, 10, 10)
                )
                lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean(
                    lv, axis=[2, 3, 4], keepdims=True
                )
                lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1)
                lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2)
                lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum(
                    lv3, axis=[2, 3, 4], keepdims=True
                )
                lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0))
                lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05))
                lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6)
                lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7)
                lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1))
                lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1))
                lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9)
                lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10)
                lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13
                R.output(gv)
            return gv

    model = GroupNorm()
    binding = {
        "w1": model.gn.weight.numpy(),
        "w2": model.gn.bias.numpy(),
    }
    verify_model(model, input_info, binding, expected1)
+
+
@tvm.testing.requires_gpu
def test_softmax():
    """Translate torch.nn.Softmax(dim=1) via the FX frontend to R.nn.softmax(axis=1)."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.sm = torch.nn.Softmax(dim=1)

        def forward(self, input):
            return self.sm(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Softmax(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_binary():
    """Translate Python binary operators (+, -, *, /, //, <) via the FX frontend.

    Each operator is checked in two forms: tensor-tensor and tensor-scalar
    (the scalar becomes an R.const operand).
    """
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    input_info2 = [([1, 3, 10, 10], "float32")]

    # Add
    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            lhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Add1(), input_info1, {}, expected1)
    verify_model(Add2(), input_info2, {}, expected2)

    # Sub
    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    @tvm.script.ir_module
    class expected4:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Sub1(), input_info1, {}, expected3)
    verify_model(Sub2(), input_info2, {}, expected4)

    # Mul
    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    @tvm.script.ir_module
    class expected5:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    @tvm.script.ir_module
    class expected6:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Mul1(), input_info1, {}, expected5)
    verify_model(Mul2(), input_info2, {}, expected6)

    # True div
    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    @tvm.script.ir_module
    class expected7:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, rhs_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    @tvm.script.ir_module
    class expected8:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(TrueDiv1(), input_info1, {}, expected7)
    verify_model(TrueDiv2(), input_info2, {}, expected8)

    # Floor div
    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    @tvm.script.ir_module
    class expected9:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    @tvm.script.ir_module
    class expected10:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(FloorDiv1(), input_info1, {}, expected9)
    verify_model(FloorDiv2(), input_info2, {}, expected10)

    # LT — comparison produces a bool tensor
    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    @tvm.script.ir_module
    class expected11:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
            rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, rhs_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
                R.output(gv)
            return gv

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    @tvm.script.ir_module
    class expected12:
        @R.function
        def main(
            lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tensor((1, 3, 10, 10), dtype="bool"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0))
                gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv
                R.output(gv)
            return gv

    verify_model(LT1(), input_info1, {}, expected11)
    verify_model(LT2(), input_info2, {}, expected12)
+
+
@tvm.testing.requires_gpu
def test_size():
    """Translate Tensor.size() via the FX frontend; static shapes fold to R.shape."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Size(Module):
        def forward(self, input):
            return input.size()

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
            # block 0
            with R.dataflow():
                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
                R.output(gv)
            return gv

    verify_model(Size(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_unsqueeze():
    """Translate Tensor.unsqueeze via the FX frontend; positive and negative dims."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Unsqueeze1(Module):
        def forward(self, input):
            return input.unsqueeze(1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1)
                gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    class Unsqueeze2(Module):
        def forward(self, input):
            return input.unsqueeze(-1)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1)
                gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Unsqueeze1(), input_info, {}, expected1)
    verify_model(Unsqueeze2(), input_info, {}, expected2)
+
+
@tvm.testing.requires_gpu
def test_getattr():
    """Translate Tensor.shape attribute access; static shapes fold to R.shape."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class GetAttr1(Module):
        def forward(self, input):
            return input.shape

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
            # block 0
            with R.dataflow():
                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
                R.output(gv)
            return gv

    verify_model(GetAttr1(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_getitem():
    """Translate tensor slicing; __getitem__ lowers to strided_slice + reshape
    (the reshape drops the integer-indexed axis back out of the sliced result)."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 1, 10, 3), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice(
                    x,
                    axes=[0, 1, 2, 3],
                    begin=[0, 1, 0, 0],
                    end=[1, T.int64(3), T.int64(10), 3],
                    strides=[1, 2, 1, 1],
                )
                lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, (1, 1, 10, 3))
                gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1
                R.output(gv)
            return gv

    verify_model(Slice1(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_unary():
    """Translate elementwise unary ops (sin, cos, sqrt) via the FX frontend."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    # sin
    class Sin(Module):
        def forward(self, input):
            return torch.sin(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Sin(), input_info, {}, expected1)

    # cos
    class Cos(Module):
        def forward(self, input):
            return torch.cos(input)

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Cos(), input_info, {}, expected2)

    # sqrt
    class Sqrt(Module):
        def forward(self, input):
            return torch.sqrt(input)

    @tvm.script.ir_module
    class expected3:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Sqrt(), input_info, {}, expected3)
+
+
@tvm.testing.requires_gpu
def test_gelu():
    """Translate torch.nn.functional.gelu via the FX frontend to R.nn.gelu."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Gelu(Module):
        def forward(self, input):
            return torch.nn.functional.gelu(input)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Gelu(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_clamp():
    """Translate torch.clamp to R.clip, and check that non-constant min/max
    bounds raise ValueError from the FX translator."""
    import torch
    from torch import fx
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Clamp(Module):
        def forward(self, input):
            return torch.clamp(input, min=0.1, max=0.5)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Clamp(), input_info, {}, expected1)

    from tvm.relax.frontend.torch import from_fx

    # max=None is not a constant the translator can use
    with pytest.raises(
        ValueError, match="TVM only supports constant max value for torch.clamp/clip"
    ):

        class Clamp_Error(Module):
            def forward(self, input):
                return torch.clamp(input, min=0.5, max=None)

        gm = fx.symbolic_trace(Clamp_Error())
        from_fx(gm, input_info)

    # tensor-valued min is likewise rejected
    with pytest.raises(
        ValueError, match="TVM only supports constant min value for torch.clamp/clip"
    ):

        class Clamp_Error(Module):
            def forward(self, input):
                return torch.clamp(input, min=input, max=input)

        gm = fx.symbolic_trace(Clamp_Error())
        from_fx(gm, input_info)
+
+
@tvm.testing.requires_gpu
def test_interpolate():
    """Translate torch.nn.functional.interpolate (default nearest mode) to
    R.image.resize2d with asymmetric coordinates and round rounding."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Interpolate(Module):
        def forward(self, input):
            return torch.nn.functional.interpolate(input, size=(5, 5))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tensor((1, 3, 5, 5), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d(
                    input_1,
                    (5, 5),
                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
                    layout="NCHW",
                    method="nearest_neighbor",
                    coordinate_transformation_mode="asymmetric",
                    rounding_method="round",
                    cubic_alpha=-0.5,
                    cubic_exclude=0,
                    extrapolation_value=0,
                    out_dtype="",
                )
                gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Interpolate(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_addmm():
    """Translate torch.addmm; it decomposes to R.matmul followed by R.add."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [
        ([10, 10], "float32"),
        ([10, 10], "float32"),
        ([10, 10], "float32"),
    ]

    class Addmm(Module):
        def forward(self, x1, x2, x3):
            return torch.addmm(x1, x2, x3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x1: R.Tensor((10, 10), dtype="float32"),
            x2: R.Tensor((10, 10), dtype="float32"),
            x3: R.Tensor((10, 10), dtype="float32"),
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32")
                lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv)
                gv: R.Tensor((10, 10), dtype="float32") = lv1
                R.output(gv)
            return gv

    verify_model(Addmm(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_split():
    """Translate torch.split with equal section size 1 along dim=1 to R.split."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 3, 10, 10], "float32")]

    class Split(Module):
        def forward(self, input):
            return torch.split(input, 1, dim=1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
        ) -> R.Tuple(
            R.Tensor((1, 1, 10, 10), dtype="float32"),
            R.Tensor((1, 1, 10, 10), dtype="float32"),
            R.Tensor((1, 1, 10, 10), dtype="float32"),
        ):
            # block 0
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                ) = R.split(input_1, indices_or_sections=3, axis=1)
                gv: R.Tuple(
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                    R.Tensor((1, 1, 10, 10), dtype="float32"),
                ) = lv
                R.output(gv)
            return gv

    verify_model(Split(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_tril():
    """Translate torch.tril with diagonal offset 1 to R.tril."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([10, 10], "float32")]

    class Tril(Module):
        def forward(self, input):
            return torch.tril(input, 1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            input_1: R.Tensor((10, 10), dtype="float32")
        ) -> R.Tensor((10, 10), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
                gv: R.Tensor((10, 10), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Tril(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_new_ones():
    """Translate Tensor.new_ones to R.full with a constant 1 fill value."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3], "float32")]

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3), dtype="float32") = R.full(
                    (1, 2, 3), R.const(1, "float32"), dtype="float32"
                )
                gv: R.Tensor((1, 2, 3), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(NewOnes(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_expand():
    """Translate Tensor.expand to R.broadcast_to."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Expand(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((4, 2, 3, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4))
                gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Expand(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_reduce():
    """Translate reduction ops (torch.sum over multiple dims) via the FX frontend."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    # sum
    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False)
                gv: R.Tensor((1, 4), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Sum(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_to():
    """Translate dtype casts (Tensor.float / Tensor.half) to R.astype."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    # float
    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 2, 3, 4), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32")
                gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(ToFloat(), input_info, {}, expected1)

    # half
    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    @tvm.script.ir_module
    class expected2:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 2, 3, 4), dtype="float16"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16")
                gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv
                R.output(gv)
            return gv

    verify_model(ToHalf(), input_info, {}, expected2)
+
+
@tvm.testing.requires_gpu
def test_permute():
    """Translate Tensor.permute to R.permute_dims."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 4, 3, 2), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
                gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Permute(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_reshape():
    """Translate Tensor.reshape to R.reshape."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tensor((2, 12), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Reshape(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_transpose():
    """Translate Tensor.transpose(1, 3); it lowers to a full R.permute_dims."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 4, 3, 2), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
                gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Transpose(), input_info, {}, expected1)
+
+
@tvm.testing.requires_gpu
def test_view():
    """Translate Tensor.view; like reshape, it lowers to R.reshape."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tensor((2, 12), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(View(), input_info, {}, expected1)
+
+
# Run all tests in this file through TVM's pytest-based test entry point.
if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_pipeline.py
b/tests/python/relax/test_pipeline.py
new file mode 100644
index 0000000000..6d6704ae97
--- /dev/null
+++ b/tests/python/relax/test_pipeline.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+from tvm import relax
+from tvm.script import relax as R
+
+
def test_pipeline_compile():
    """End-to-end check of relax.get_pipeline(): lower a one-op IRModule,
    build it for LLVM, and execute it on the Relax VM against numpy."""
    pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
            lv0 = R.add(x, y)
            return lv0

    mod = Mod
    mod = pipeline(mod)
    target = tvm.target.Target("llvm", host="llvm")

    ex = relax.vm.build(mod, target)
    x_np = np.random.rand(3, 4).astype(np.float32)
    y_np = np.random.rand(3, 4).astype(np.float32)
    x = tvm.nd.array(x_np)
    y = tvm.nd.array(y_np)

    vm = relax.VirtualMachine(ex, tvm.cpu())
    z = vm["main"](x, y)
    tvm.testing.assert_allclose(z.numpy(), x_np + y_np, rtol=1e-7, atol=1e-7)