This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 405ef9dbf67beb5febe499e3e4f1c2c067a6e237 Author: Ruihang Lai <[email protected]> AuthorDate: Sat Feb 18 11:19:48 2023 -0500 [Unity] Initial PyTorch Frontend (#14037) [Unity] Initial PyTorch Frontend This PR introduces initial pytorch frontend components of Relax, including - a FX translator that translates a Torch FX graph module to an TVM IRModule, - a Relax-backend of Torch Dynamo, which brings the mechanism to build PyTorch model using Relax compilation pipeline, - a pipeline prototype that contains the collection of pre-defined pipelines that optimizes and lower IRModule before passing to minimum build. Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Tianqi Chen <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> --- python/tvm/relax/__init__.py | 3 + python/tvm/relax/frontend/__init__.py | 19 + python/tvm/relax/frontend/torch/__init__.py | 21 + python/tvm/relax/frontend/torch/dynamo.py | 156 ++ python/tvm/relax/frontend/torch/fx_translator.py | 820 ++++++++++ python/tvm/relax/pipeline.py | 84 ++ tests/python/relax/test_frontend_dynamo.py | 198 +++ tests/python/relax/test_frontend_from_fx.py | 1729 ++++++++++++++++++++++ tests/python/relax/test_pipeline.py | 45 + 9 files changed, 3075 insertions(+) diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index cfcf7876dc..33a9c2eece 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -73,6 +73,9 @@ from .struct_info import ( FuncStructInfo, ) +# pipeline +from .pipeline import get_pipeline + # Import submodules in the last to avoid dependency from . import exec_builder from . import expr diff --git a/python/tvm/relax/frontend/__init__.py b/python/tvm/relax/frontend/__init__.py new file mode 100644 index 0000000000..6c9c188aaa --- /dev/null +++ b/python/tvm/relax/frontend/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Frontends for constructing Relax programs, with the model importers +""" diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py new file mode 100644 index 0000000000..55da5a456d --- /dev/null +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +PyTorch Frontends for constructing Relax programs, with the model importers +""" +from .fx_translator import from_fx +from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py new file mode 100644 index 0000000000..94de73a431 --- /dev/null +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, missing-function-docstring, not-callable +# pylint: disable=import-outside-toplevel, unused-argument +# mypy: ignore-errors +"""PyTorch Dynamo backend of Relax.""" +import functools +from typing import Optional + +import tvm +from tvm.relax.vm import build as relax_build +from tvm.relax.frontend.torch.fx_translator import from_fx + + +def device_from_inputs(example_inputs): + for x in example_inputs: + if hasattr(x, "device"): + return x.device + return None + + +def relax_dynamo(pipeline: Optional[tvm.transform.Pass] = None): + """A helper function to create a relax backend. + + Parameters + ---------- + pipeline : Optional[tvm.transform.Pass] + The pipeline to be applied to the relax module before sent to build. 
+ + Returns + ------- + backend : Callable[[torch.fx.GraphModule, List[torch.Tensor]], Callable] + The relax dynamo backend. + """ + + def _relax_backend(graph_module, example_inputs): + import torch # type: ignore[import] + + assert isinstance(graph_module, torch.fx.GraphModule) + + def to_torch_tensor(nd_tensor): + """A helper function to transfer a NDArray to torch.tensor.""" + if isinstance(nd_tensor, tvm.nd.NDArray): + return torch.from_numpy(nd_tensor.numpy()) + elif isinstance(nd_tensor, tvm.ir.Array): + return tuple(to_torch_tensor(x) for x in nd_tensor) + else: + raise ValueError(f"Unsupported type {type(nd_tensor)}") + + def to_tvm_tensor(torch_tensor): + """A helper function to transfer a torch.tensor to NDArray.""" + if not isinstance(torch_tensor, torch._subclasses.fake_tensor.FakeTensor): + return tvm.nd.array(torch_tensor.numpy()) + # Fake Tensor + real_tensor = torch.randn(torch_tensor.shape, dtype=torch_tensor.dtype) + return tvm.nd.array(real_tensor.numpy()) + + device = device_from_inputs(example_inputs) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + mod = from_fx(graph_module, input_info) + + if device.type == "cuda": + dev = tvm.cuda(device.index) + target = tvm.target.cuda() + else: + dev = tvm.cpu(0) + target = tvm.target.Target(llvm_target()) + + # invoke optimization pipeline. 
+ if pipeline is None: + # get default pipeline + seq = tvm.relax.get_pipeline() + elif isinstance(pipeline, str): + # lookup by name + seq = tvm.relax.get_pipeline(pipeline) + else: + seq = pipeline + + mod = mod.with_attr("target", target) + mod = seq(mod) + + ex = relax_build(mod, target=target) + + vm = tvm.relax.vm.VirtualMachine(exec=ex.mod, device=dev) + + def exec_tvm(*i_args): + args = [a.contiguous() for a in i_args] + vm_args = list() + for arg in args: + if arg.dim() != 0: + if arg.requires_grad: + arg = arg.detach() + vm_args.append(to_tvm_tensor(arg)) + outputs = vm["main"](*vm_args) + return to_torch_tensor(outputs) + + return exec_tvm + + return _relax_backend + + +def dynamo_capture_subgraphs(model, *params) -> tvm.ir.IRModule: + """Capture subgraphs of the PyTorch model using torch.compile into an IRModule. + + Parameters + ---------- + model : torch.nn.Module + The PyTorch model to be captured. + + params : List[torch.Tensor] + The parameters of the PyTorch model. + + Returns + ------- + mod : tvm.ir.IRModule + The IRModule that contains captured subgraphs. 
+ """ + import torch # type: ignore[import] + from torch import fx # type: ignore[import] + from torch import _dynamo as dynamo # type: ignore[import] + + mod = tvm.IRModule() + + def _capture(graph_module: fx.GraphModule, example_inputs): + assert isinstance(graph_module, torch.fx.GraphModule) + input_info = [(tuple(tensor.shape), str(tensor.dtype)) for tensor in example_inputs] + subgraph = from_fx(graph_module, input_info) + mod["subgraph_" + str(len(mod.get_global_vars()))] = subgraph["main"] + return graph_module.forward + + dynamo.reset() + compiled_model = torch.compile(model, backend=_capture) + compiled_model(*params) + return mod + + [email protected]_cache(None) +def llvm_target(): + if "avx512" in open("/proc/cpuinfo").read(): + return "llvm -mcpu=skylake-avx512" + return "llvm -mcpu=core-avx2" diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py new file mode 100644 index 0000000000..582f2edbcf --- /dev/null +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -0,0 +1,820 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch FX frontend of Relax.""" +from typing import Callable, Dict, List, Tuple, Union +from functools import reduce + +import tvm +from tvm import relax + + +class TorchFXImporter: + """An importer from PyTorch FX to Relax.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.node.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Constant] = {} + self.named_modules: Dict[str, torch.Module] = None + self.block_builder: relax.BlockBuilder = None + self.create_convert_map() + + ########## Utilities ########## + @staticmethod + def _fetch_attr(model, target: str): + import torch # type: ignore + + target_atoms = target.split(".") + attr_itr = model + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced non existing target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + if isinstance(attr_itr, torch.Tensor): + return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return attr_itr + + @staticmethod + def _convert_data_type(input_type): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + input_type = input_type.lower() + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + shape = tensor.data.shape + dtype = 
TorchFXImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), relax.TensorStructInfo(shape, dtype)) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.node.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + @staticmethod + def _promote_binary_op_args(lhs, rhs): + if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): + return lhs, rhs + elif isinstance(lhs, relax.Expr): + assert isinstance(lhs.struct_info, relax.TensorStructInfo) + return lhs, relax.const(rhs, lhs.struct_info.dtype) + elif isinstance(rhs, relax.Expr): + assert isinstance(rhs.struct_info, relax.TensorStructInfo) + return relax.const(lhs, rhs.struct_info.dtype), rhs + else: + assert False + + def _call_binary_op(self, op, lhs, rhs): + lhs, rhs = TorchFXImporter._promote_binary_op_args(lhs, rhs) + return self.block_builder.emit(op(lhs, rhs)) + + ########## Arithmetic ########## + + def _cos(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.cos(self.env[node.args[0]])) + + def _sin(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.sin(self.env[node.args[0]])) + + def 
_sqrt(self, node: fx.node.Node) -> relax.Expr: + arg = self.env[node.args[0]] + if isinstance(arg, (int, float)): + arg = relax.const(arg, "float32") + return self.block_builder.emit(relax.op.sqrt(arg)) + + def _add(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.add, lhs, rhs) + return lhs + rhs + + def _floordiv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.floor_divide, lhs, rhs) + return lhs // rhs + + def _mul(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.multiply, lhs, rhs) + return lhs * rhs + + def _sub(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.subtract, lhs, rhs) + return lhs - rhs + + def _truediv(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): + return self._call_binary_op(relax.op.divide, lhs, rhs) + return lhs / rhs + + def _clamp(self, node: fx.node.Node) -> relax.Expr: + args = self.retrieve_args(node) + a_min = node.kwargs["min"] + a_max = node.kwargs["max"] + if not isinstance(a_min, (int, float)): + raise ValueError( + f"TVM only supports constant min value for torch.clamp/clip, " + f"but got {a_min} with type {type(a_min)}" + ) + if not isinstance(a_max, (int, float)): + raise ValueError( + f"TVM only supports constant max value for torch.clamp/clip, " + f"but got {a_max} with type {type(a_max)}" + ) + return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max)) + + ########## Compare ########## + + def _lt(self, 
node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less, lhs, rhs) + + ########## Creation ########## + + def _tril(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + return self.block_builder.emit(relax.op.create.tril(x, k)) + + def _new_ones(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0])) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + ########## DataType ########## + + def _float(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.astype(args[0], args[1])) + + ########## Linear Algebra ########## + + def _matmul_impl(self, a: relax.Expr, b: relax.Expr): + return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) + + def _matmul(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + res = self._matmul_impl( + args[0], + args[1], + ) + return res + + def _addmm(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + matmul = 
self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + return self.block_builder.emit(relax.op.add(x, matmul)) + + ########## Manipulation ########## + + def _cat(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.concat(args[0], axis=node.kwargs["dim"])) + + def _expand(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + + def _flatten(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) + + def _split(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + n_section = 
(self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) + + def _transpose(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) + + ########## Neural Network ########## + + def _linear(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = None if module.bias is None else self.params[module.bias] + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _conv2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + kernel = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + else: + nargs = len(node.args) + kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] + stride = node.args[2] if nargs > 2 else node.kwargs["stride"] + padding = node.args[3] if nargs > 3 else node.kwargs["padding"] + dilation = 
node.args[4] if nargs > 4 else node.kwargs["dilation"] + ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + + stride = kernel if stride is None else stride + + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel, + strides=stride, + padding=padding, + dilation=dilation, + layout="NCHW", + ceil_mode=ceil_mode, + ) + ) + + def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + from torch import fx + + def _impl(node: fx.node.Node) -> relax.Var: + if is_module: + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + else: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + return _impl + + def _softmax(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + dim = module.dim + else: + nargs = len(node.args) + dim = node.args[1] if nargs > 1 else node.kwargs["dim"] + assert dim is not None + return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + + def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params[module.bias] + dtype = self._convert_data_type(str(module.running_mean.dtype)) + running_mean = relax.const(module.running_mean.cpu().detach().numpy(), dtype) + running_var = relax.const(module.running_var.cpu().detach().numpy(), dtype) + eps = module.eps + + res_tuple = self.block_builder.emit( + relax.op.nn.batch_norm( + x, + weight, + bias, + running_mean, + running_var, + axis=1, + epsilon=eps, + ) + ) + + return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) + + def _layer_norm(self, node: fx.node.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = 
self.named_modules[node.target] + + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.checked_type) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.checked_type) + dim_num = len(module.normalized_shape) + axes = list(range(-dim_num, 0)) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=module.eps, + ) + ) + + def _group_norm(self, node: fx.node.Node) -> relax.Var: + # torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, + # affine=True, device=None, dtype=None) + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + num_channels = module.num_channels + eps = module.eps + affine = module.affine + + shape = self.shape_of(x) + assert len(shape) == 4 + N, C, H, W = shape[0], shape[1], shape[2], shape[3] + assert C == num_channels + assert C % num_groups == 0 + grouped_x = self.block_builder.emit( + relax.op.reshape(x, [N, num_groups, C // num_groups, H, W]) + ) + mean_x = self.block_builder.emit(relax.op.mean(grouped_x, [2, 3, 4], keepdims=True)) + sub_x = self.block_builder.emit(relax.op.subtract(grouped_x, mean_x)) + square_x = self.block_builder.emit(relax.op.multiply(sub_x, sub_x)) + sum_square_x = self.block_builder.emit(relax.op.sum(square_x, [2, 3, 4], keepdims=True)) + var_x = self._call_binary_op(relax.op.divide, sum_square_x, (C // num_groups * H * W).value) + var_x_eps = self._call_binary_op(relax.op.add, var_x, eps) + std_x = self.block_builder.emit(relax.op.sqrt(var_x_eps)) + norm_x = self.block_builder.emit(relax.op.divide(sub_x, std_x)) + + if affine: + weight = self.params[module.weight] + bias = self.params[module.bias] + weight_reshape = self.block_builder.emit( + relax.op.reshape(weight, (1, num_groups, C // num_groups, 1, 1)) + ) + bias_reshape = self.block_builder.emit( + relax.op.reshape(bias, (1, 
num_groups, C // num_groups, 1, 1)) + ) + norm_x = self.block_builder.emit(relax.op.multiply(norm_x, weight_reshape)) + norm_x = self.block_builder.emit(relax.op.add(norm_x, bias_reshape)) + return self.block_builder.emit(relax.op.reshape(norm_x, (N, C, H, W))) + + def _embedding(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + x = self.block_builder.emit(relax.op.astype(x, "int32")) + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + + def _interpolate(self, node: fx.node.Node) -> relax.Var: + # torch.nn.functional.interpolate( + # input, size=None, scale_factor=None, mode='nearest', align_corners=None, + # recompute_scale_factor=None, antialias=False) + # (TODO) this is a temporary implementation for interpolate that only considers NCHW layout + # it basically replicates the implementation in tvm.relay.frontend.pytorch + data = self.env[node.args[0]] + size = node.kwargs["size"] + scale_factor = node.kwargs["scale_factor"] + method = node.kwargs["mode"] + align_corners = node.kwargs["align_corners"] + recompute_scale_factor = node.kwargs["recompute_scale_factor"] + antialias = node.kwargs["antialias"] + + assert recompute_scale_factor is None + assert antialias is False + + if size is None: + shape = self.shape_of(data) + assert isinstance(shape, relax.ShapeExpr) + size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + + if method.startswith("nearest"): + method = "nearest_neighbor" + elif method[0:2] == "bi": + method = method[2:] + + if method == "nearest_neighbor": + coord_trans = "asymmetric" + elif align_corners: + coord_trans = "align_corners" + else: + coord_trans = "half_pixel" + + return self.block_builder.emit( + relax.op.image.resize2d( + data, size, layout="NCHW", method=method, coordinate_transformation_mode=coord_trans + ) + ) + + ########## Others ########## + + def _size(self, node: fx.node.Node) -> 
relax.Expr: + x = self.env[node.args[0]] + shape = self.shape_of(x) + if len(node.args) == 1: + assert isinstance(shape, relax.ShapeExpr) + return shape + assert len(node.args) == 2 + idx = node.args[1] + return self.shape_of(x)[idx].value + + def _getattr(self, node: fx.node.Node) -> relax.Var: + if isinstance(self.env[node.args[0]], relax.Expr): + if node.args[1] == "dtype": + return self.env[node.args[0]].struct_info.dtype + elif node.args[1] == "shape": + return self.shape_of(self.env[node.args[0]]) + return getattr(self.env[node.args[0]], node.args[1]) + + def _getitem(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): + return x[node.args[1]] + elif isinstance(x, relax.Var): + if isinstance(x.struct_info, relax.TupleStructInfo): + return self.block_builder.emit(relax.TupleGetItem(x, node.args[1])) + + assert isinstance(x.struct_info, relax.TensorStructInfo) + begin = [] + end = [] + stride = [] + axes = [] + expand_dim = [] + i = 0 + shape = self.shape_of(x) + for index in node.args[1]: + if isinstance(index, int): + begin.append(index) + end.append(index + 1) + stride.append(1) + axes.append(i) + i = i + 1 + elif isinstance(index, slice): + begin.append(0 if index.start is None else index.start) + end.append(shape[i] if index.stop is None else index.stop) + stride.append(1 if index.step is None else index.step) + axes.append(i) + i = i + 1 + elif index is None: + expand_dim.append(i) + i = i + 1 + else: + raise ValueError("Unsupported index type: " + str(type(index))) + while i < len(shape): + begin.append(0) + end.append(shape[i]) + axes.append(i) + i = i + 1 + sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + sliced_shape = list(self.shape_of(sliced)) + for i in expand_dim: + sliced_shape.insert(i, 1) + return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) + else: + assert False + + def create_convert_map(self): + from 
torch import nn + from torch import fx + + self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.node.Node], relax.Var]] = { + # call_module + nn.Linear: self._linear, + nn.Conv2d: self._conv2d, + nn.MaxPool2d: self._max_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.Softmax: self._softmax, + nn.ReLU: lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + nn.ReLU6: lambda node: self.block_builder.emit( + relax.op.clip(self.env[node.args[0]], 0, 6) + ), + nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.Flatten: self._flatten, + nn.BatchNorm2d: self._batch_norm_2d, + nn.LayerNorm: self._layer_norm, + nn.GroupNorm: self._group_norm, + nn.Dropout: lambda node: self.env[node.args[0]], + nn.modules.sparse.Embedding: self._embedding, + # call_function and call_method + "cos": self._cos, + "sin": self._sin, + "add": self._add, + "floordiv": self._floordiv, + "mul": self._mul, + "sub": self._sub, + "sqrt": self._sqrt, + "lt": self._lt, + "truediv": self._truediv, + "new_ones": self._new_ones, + "tril": self._tril, + "sum": self._sum, + "float": self._float, + "half": self._half, + "type": self._type, + "matmul": self._matmul, + "addmm": self._addmm, + "cat": self._cat, + "expand": self._expand, + "flatten": self._flatten, + "permute": self._permute, + "reshape": self._reshape, + "split": self._split, + "transpose": self._transpose, + "unsqueeze": lambda node: self.block_builder.emit( + relax.op.expand_dims(self.env[node.args[0]], node.args[1]) + ), + "view": self._reshape, + "softmax": self._softmax, + "clamp": self._clamp, + "relu": lambda node: self.block_builder.emit(relax.op.nn.relu(self.env[node.args[0]])), + "gelu": lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])), + "interpolate": self._interpolate, + "size": self._size, + "getattr": self._getattr, + "getitem": self._getitem, + "contiguous": lambda node: self.env[node.args[0]], 
+ "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + } + + def from_fx(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program.""" + from torch import fx + + self.named_modules = dict(model.named_modules()) + + graph: fx.Graph = model.graph + # Create input variables. + inputs = list() + for idx, (shape, dtype) in enumerate(input_info): + inputs.append( + relax.Var( + f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) + ) + ) + + # Initialize the block builder with a function and a dataflow block. + self.block_builder = relax.BlockBuilder() + with self.block_builder.function(name="main", params=inputs.copy()): + output = None + with self.block_builder.dataflow(): + # Translate model parameters. + for _, param in model.named_parameters(): + shape = param.data.shape + dtype = self._convert_data_type(str(param.data.dtype)) + if dtype in ("float32", "float16"): + self.params[param] = relax.const( + param.data.cpu().numpy(), relax.TensorStructInfo(shape, dtype) + ) + else: + raise ValueError("Unsupported data type for model parameters: %s" % dtype) + # Translate the model. 
+ for node in graph.nodes: + if node.op == "placeholder": + assert len(inputs) > 0, "Provided inputs is less than actual inputs" + self.env[node] = inputs.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = TorchFXImporter._fetch_attr(model, node.target) + elif node.op == "call_module": + module = self.named_modules[node.target] + assert ( + type(module) in self.convert_map + ), f"Unsupported module type {type(module)}" + self.env[node] = self.convert_map[type(module)](node) + elif node.op == "call_function": + func_name = node.name.rstrip("0123456789_") + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + elif node.op == "call_method": + assert ( + node.target in self.convert_map + ), f"Unsupported function target {node.target}" + self.env[node] = self.convert_map[node.target](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + return self.block_builder.get() + + +def from_fx(model, input_info: List[Tuple[Tuple[int], str]]) -> tvm.IRModule: + """Convert a PyTorch FX GraphModule to a Relax program + + Parameters + ---------- + model : fx.GraphModule + The PyTorch FX GraphModule to convert. + + input_info : List[Tuple[Tuple[int], str]] + A list of shapes and data types of input tensors. + + Returns + ------- + module : tvm.IRModule + The converted Relax program. + + Examples + -------- + Users can use the FX tracer or dynamo.export() to extract + a fx.GraphModule from a PyTorch model. The following codes show + how to convert a PyTorch model to a Relax program. + + .. code-block:: python + + # Import the importer. 
+ import numpy as np + import torch + from tvm.relax.frontend.torch import from_fx + from torch import _dynamo as dynamo + + # Define the module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) + + def forward(self, input): + return self.linear(input) + + # Instantiate the model and create the input info dict. + torch_model = MyModule() + input_info = [((128, 10), "float32")] + input_tensors = [ + torch.as_tensor(np.random.randn(*shape).astype(dtype)) + for shape, dtype in input_info + ] + + # Use FX tracer to trace the PyTorch model. + graph_module = fx.symbolic_trace(torch_model) + + # Use the dynamo.export() to export the PyTorch model to FX. + try: + graph_module = dynamo.export(torch_model, *input_tensors) + except: + raise RuntimeError("Failed to export the PyTorch model to FX.") + + # Use the importer to import the PyTorch model to Relax. + mod: tvm.IRModule = from_fx(graph_module, input_info) + + # Print out the imported model. + print(mod.script()) + + Notes + ----- + For a given PyTorch model, to look up the names of the model inputs in + FX, one can use + + .. code-block:: python + + fx.symbolic_trace(model).graph.print_tabular() + + to print out the tabular representation of the PyTorch module, and then + check the placeholder rows in the beginning of the tabular. + """ + return TorchFXImporter().from_fx(model, input_info) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py new file mode 100644 index 0000000000..a5da15b76d --- /dev/null +++ b/python/tvm/relax/pipeline.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Pre-defined pipelines. + +Relax enables flexible pipeline optimizations before minimum build. +This namespace offers a pre-defined collection that can be used +as it is or serve as a basis for further composition. +""" +# pylint: disable=unused-argument +import tvm +from tvm import meta_schedule as ms +from . import transform + + +@tvm.transform.module_pass(opt_level=0) +def zero_pipeline(mod: tvm.ir.IRModule, ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """Pipeline that applies pre-tuned logs. + + Parameters + ---------- + mod : tvm.ir.IRModule + Input IRModule. + + ctx : tvm.transform.PassContext + The pass context + + Returns + ------- + mod: tvm.ir.IRModule + The result transformed module. + """ + seq = tvm.transform.Sequential( + [ + transform.LegalizeOps(), + transform.AnnotateTIROpPattern(), + transform.FoldConstant(), + transform.FuseOps(), + transform.FuseTIR(), + ] + ) + mod = seq(mod) + if ms.Database.current(): + mod = transform.MetaScheduleApplyDatabase()(mod) + return mod + + +# global map of pre-built pipelines +PIPELINE_MAP = {"zero": zero_pipeline} + + +def get_pipeline(name: str = "zero") -> tvm.transform.Pass: + """Get pre-built pipeline by name + + Parameters + ---------- + name : Optional[str] + Name of the pipeline + + Returns + ------- + pipeline: tvm.transform.Pass + The transformation pipeline.
+ """ + + if name in PIPELINE_MAP: + return PIPELINE_MAP[name] + else: + raise ValueError( + f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" + ) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py new file mode 100644 index 0000000000..370df2103d --- /dev/null +++ b/tests/python/relax/test_frontend_dynamo.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +pytest.importorskip("torch._dynamo") + + +import tvm +from tvm import relax, meta_schedule as ms, tir +import tvm.testing +import torch +import torch._dynamo as dynamo +from tvm.relax.frontend.torch import relax_dynamo +from tvm.script.parser import relax as R, tir as T + + +def test_relax_dynamo(): + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + model = Input1() + ### construct the database + @tvm.script.ir_module + class Input1_ir: + @T.prim_func + def main( + inp_0: T.Buffer[(T.int64(10), T.int64(100)), "float32"], + param_0: T.Buffer[(T.int64(100), T.int64(10)), "float32"], + param_1: T.Buffer[T.int64(10), "float32"], + compute: T.Buffer[(T.int64(10), T.int64(10)), "float32"], + ): + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + matmul = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + T_add = T.alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") + for i0, i1, k in T.grid(T.int64(10), T.int64(10), T.int64(100)): + with T.block("matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + T.reads(inp_0[v_i0, v_k], param_0[v_k, v_i1]) + T.writes(matmul[v_i0, v_i1]) + with T.init(): + matmul[v_i0, v_i1] = T.float32(0) + matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + inp_0[v_i0, v_k] * param_0[v_k, v_i1] + for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(matmul[v_ax0, v_ax1], param_1[v_ax1]) + T.writes(T_add[v_ax0, v_ax1]) + T_add[v_ax0, v_ax1] = matmul[v_ax0, v_ax1] + param_1[v_ax1] + for i0, i1 in T.grid(T.int64(10), T.int64(10)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(T_add[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(T_add[v_i0, v_i1], T.float32(0)) + + db = 
ms.Database.create("memory") + workload = db.commit_workload(Input1_ir) + + sch = tir.Schedule(Input1_ir, debug_mask="all") + b0 = sch.get_block(name="matmul", func_name="main") + b1 = sch.get_block(name="T_add", func_name="main") + b2 = sch.get_block(name="root", func_name="main") + sch.compute_inline(block=b1) + sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") + l3, l4, l5 = sch.get_loops(block=b0) + v6, v7, v8, v9 = sch.sample_perfect_tile( + loop=l3, n=4, max_innermost_factor=64, decision=[1, 2, 5, 1] + ) + l10, l11, l12, l13 = sch.split(loop=l3, factors=[v6, v7, v8, v9], preserve_unit_iters=True) + v14, v15, v16, v17 = sch.sample_perfect_tile( + loop=l4, n=4, max_innermost_factor=64, decision=[1, 1, 10, 1] + ) + l18, l19, l20, l21 = sch.split(loop=l4, factors=[v14, v15, v16, v17], preserve_unit_iters=True) + v22, v23 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64, decision=[100, 1]) + l24, l25 = sch.split(loop=l5, factors=[v22, v23], preserve_unit_iters=True) + sch.reorder(l10, l18, l11, l19, l24, l12, l20, l25, l13, l21) + (b26,) = sch.get_consumers(block=b0) + sch.reverse_compute_at(block=b26, loop=l18, preserve_unit_loops=True, index=-1) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.parallel", ann_val=96) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.vectorize", ann_val=64) + v27 = sch.sample_categorical( + candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=0 + ) + sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v27) + + tuning_record = ms.database.TuningRecord(sch.trace, workload, run_secs=[0.0]) + db.commit_tuning_record(tuning_record) + ### Optimize the model with tuned-log + with db: + opt_model = torch.compile(model, backend=relax_dynamo()) + inp = torch.randn(10, 100) + tvm.testing.assert_allclose( + opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 + ) + + +def test_subgraph_capture(): + 
import torch + from tvm.relax.frontend.torch.dynamo import dynamo_capture_subgraphs + + class Input1(torch.nn.Module): + def __init__(self): + super().__init__() + self.lin = torch.nn.Linear(100, 10) + + def forward(self, x): + return torch.nn.functional.relu(self.lin(x)) + + @tvm.script.ir_module + class Expected1: + @R.function + def subgraph_0( + inp_0: R.Tensor((10, 100), dtype="float32"), + w0: R.Tensor((10, 100), dtype="float32"), + w1: R.Tensor((10,), dtype="float32"), + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((100, 10), dtype="float32") = R.permute_dims(w0, axes=None) + lv1: R.Tensor((10, 10), dtype="float32") = R.matmul(inp_0, lv, out_dtype="float32") + lv2: R.Tensor((10, 10), dtype="float32") = R.add(lv1, w1) + lv3: R.Tensor((10, 10), dtype="float32") = R.nn.relu(lv2) + gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + model = Input1() + mod = dynamo_capture_subgraphs(model, torch.randn(10, 100)) + binding = {"w0": model.lin.weight.detach().numpy(), "w1": model.lin.bias.detach().numpy()} + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("subgraph_0", binding)(Expected1) + tvm.ir.assert_structural_equal(mod, expected) + + def Input2(a, b): + x = a / (torch.sin(a) + 1) + if torch.sum(b) < 1: + b = b * -1 + return x * b + + @tvm.script.ir_module + class Expected2: + @R.function + def subgraph_0( + inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,), dtype="float32") + ) -> R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0) + lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1, "float32")) + lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1) + lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None, keepdims=False) + lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1, 
"float32")) + gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((), dtype="bool")) = ( + lv2, + lv4, + ) + R.output(gv) + return gv + + @R.function + def subgraph_1( + inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,), dtype="float32") + ) -> R.Tuple(R.Tensor((10,), dtype="float32")): + # block 0 + with R.dataflow(): + lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_11, inp_01) + gv1: R.Tuple(R.Tensor((10,), dtype="float32")) = (lv5,) + R.output(gv1) + return gv1 + + mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10)) + tvm.ir.assert_structural_equal(mod, Expected2) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py new file mode 100644 index 0000000000..9b35d34bd3 --- /dev/null +++ b/tests/python/relax/test_frontend_from_fx.py @@ -0,0 +1,1729 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +import tvm +from tvm import relax +import tvm.testing +from tvm.script.parser import relax as R, tir as T + + +def verify_model(torch_model, input_info, binding, expected): + from torch import fx + from tvm.relax.frontend.torch import from_fx + + graph_model = fx.symbolic_trace(torch_model) + mod = from_fx(graph_model, input_info) + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + [email protected]_gpu +def test_conv(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3 + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tensor((1, 6, 4, 4), dtype="float32"): + # block 0 + 
with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Conv2D1() + binding = {"w1": model.conv.weight.numpy(), "w2": model.conv.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + [email protected]_gpu +def test_linear(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + # nn.Linear + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((1, 7), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv2 + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((7, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 7), 
dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + gv: R.Tensor((1, 3, 10, 7), dtype="float32") = lv1 + R.output(gv) + return gv + + input_info = [([1, 3, 10, 10], "float32")] + + model = Dense1() + binding = {"w1": model.linear.weight.numpy(), "w2": model.linear.bias.numpy()} + verify_model(model, input_info, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.numpy()} + verify_model(model, input_info, binding, expected2) + + # matmul + class MatMul1(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32"), + input_2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul( + input_1, input_2, out_dtype="float32" + ) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model( + MatMul1(), + [([10, 10], "float32"), ([10, 10], "float32")], + {}, + expected3, + ) + + [email protected]_gpu +def test_relu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + 
input_info = [([10, 10], "float32")] + verify_model(ReLU0(), input_info, {}, expected) + verify_model(ReLU1(), input_info, {}, expected) + + [email protected]_gpu +def test_relu6(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class ReLU6(Module): + def __init__(self): + super().__init__() + self.relu6 = torch.nn.ReLU6() + + def forward(self, input): + return self.relu6(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.clip(input, 0, 6) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(ReLU6(), input_info, {}, expected) + + [email protected]_gpu +def test_maxpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: 
R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 4, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 4, 4), dtype="float32") = lv + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 6, 6), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tensor((1, 3, 6, 6), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(MaxPool2d(), input_info, {}, expected1) + verify_model(MaxPool2d2(), input_info, {}, expected2) + verify_model(MaxPool2d3(), input_info, {}, expected3) + + [email protected]_gpu +def test_adaptive_avgpool2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + 
lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(AdaptiveAvgPool2d0(), input_info, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), input_info, {}, expected1) + + [email protected]_gpu +def test_flatten(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Flatten(Module): + def __init__(self): + super().__init__() + self.f = torch.nn.Flatten(2, -1) + + def forward(self, input): + return self.f(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 100), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) + gv: R.Tensor((1, 3, 100), dtype="float32") = lv + R.output(gv) + return gv + + # call_module + verify_model(Flatten(), input_info, {}, expected1) + # call_method + verify_model(torch.nn.Flatten(2, -1), input_info, {}, expected1) + + [email protected]_gpu +def test_batchnorm2d(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class BatchNorm2d(Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, input): + return self.bn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + w3: R.Tensor((3,), dtype="float32"), + w4: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 3, 
10, 10), dtype="float32"), + R.Tensor((3,), dtype="float32"), + R.Tensor((3,), dtype="float32"), + ) = R.nn.batch_norm( + input_1, + w1, + w2, + w3, + w4, + axis=1, + epsilon=1e-05, + center=True, + scale=True, + ) + lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + model = BatchNorm2d() + binding = { + "w1": model.bn.weight.numpy(), + "w2": model.bn.bias.numpy(), + "w3": model.bn.running_mean.numpy(), + "w4": model.bn.running_var.numpy(), + } + verify_model(BatchNorm2d(), input_info, binding, expected1) + + [email protected]_gpu +def test_embedding(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([4], "int64")] + + class Embedding(Module): + def __init__(self): + super().__init__() + self.embedding = torch.nn.Embedding(10, 3) + + def forward(self, input): + return self.embedding(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") + ) -> R.Tensor((4, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") + lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) + gv: R.Tensor((4, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + model = Embedding() + binding = {"w1": model.embedding.weight.numpy()} + verify_model(model, input_info, binding, expected1) + + [email protected]_gpu +def test_dropout(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Dropout(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: 
R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = input_1 + R.output(gv) + return gv + + verify_model(Dropout(), input_info, {}, expected1) + + [email protected]_gpu +def test_layernorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class LayerNorm(Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm((10, 10)) + + def forward(self, input): + return self.ln(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((10, 10), dtype="float32"), + w2: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( + input_1, + w1, + w2, + axes=[-2, -1], + epsilon=1e-05, + center=True, + scale=True, + ) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + model = LayerNorm() + binding = { + "w1": model.ln.weight.numpy(), + "w2": model.ln.bias.numpy(), + } + verify_model(LayerNorm(), input_info, binding, expected1) + + [email protected]_gpu +def test_silu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class SiLU(Module): + def __init__(self): + super().__init__() + self.silu = torch.nn.SiLU() + + def forward(self, input): + return self.silu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.silu(input_1) + gv: R.Tensor((1, 3, 10, 10), 
dtype="float32") = lv + R.output(gv) + return gv + + verify_model(SiLU(), input_info, {}, expected1) + + [email protected]_gpu +def test_groupnorm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GroupNorm(Module): + def __init__(self): + super().__init__() + self.gn = torch.nn.GroupNorm(3, 3) + + def forward(self, input): + return self.gn(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((3,), dtype="float32"), + w2: R.Tensor((3,), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.reshape( + input_1, (1, 3, 1, 10, 10) + ) + lv1: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.mean( + lv, axis=[2, 3, 4], keepdims=True + ) + lv2: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.subtract(lv, lv1) + lv3: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv2, lv2) + lv4: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sum( + lv3, axis=[2, 3, 4], keepdims=True + ) + lv5: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.divide(lv4, R.const(100.0)) + lv6: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.add(lv5, R.const(1e-05)) + lv7: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.sqrt(lv6) + lv8: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.divide(lv2, lv7) + lv9: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w1, (1, 3, 1, 1, 1)) + lv10: R.Tensor((1, 3, 1, 1, 1), dtype="float32") = R.reshape(w2, (1, 3, 1, 1, 1)) + lv11: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.multiply(lv8, lv9) + lv12: R.Tensor((1, 3, 1, 10, 10), dtype="float32") = R.add(lv11, lv10) + lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.reshape(lv12, (1, 3, 10, 10)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv13 + R.output(gv) + return gv + + model = 
GroupNorm() + binding = { + "w1": model.gn.weight.numpy(), + "w2": model.gn.bias.numpy(), + } + verify_model(model, input_info, binding, expected1) + + [email protected]_gpu +def test_softmax(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Softmax(Module): + def __init__(self): + super().__init__() + self.sm = torch.nn.Softmax(dim=1) + + def forward(self, input): + return self.sm(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Softmax(), input_info, {}, expected1) + + [email protected]_gpu +def test_binary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info1 = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")] + input_info2 = [([1, 3, 10, 10], "float32")] + # Add + class Add1(Module): + def forward(self, lhs, rhs): + return lhs + rhs + + @tvm.script.ir_module + class expected1: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lhs, rhs) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Add2(Module): + def forward(self, lhs): + return lhs + 1.0 + + @tvm.script.ir_module + class expected2: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), 
dtype="float32") = R.add(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Add1(), input_info1, {}, expected1) + verify_model(Add2(), input_info2, {}, expected2) + + # Sub + class Sub1(Module): + def forward(self, lhs, rhs): + return lhs - rhs + + @tvm.script.ir_module + class expected3: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Sub2(Module): + def forward(self, lhs): + return lhs - 1.0 + + @tvm.script.ir_module + class expected4: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sub1(), input_info1, {}, expected3) + verify_model(Sub2(), input_info2, {}, expected4) + + # Mul + class Mul1(Module): + def forward(self, lhs, rhs): + return lhs * rhs + + @tvm.script.ir_module + class expected5: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Mul2(Module): + def forward(self, lhs): + return lhs * 1.0 + + @tvm.script.ir_module + class expected6: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), 
dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Mul1(), input_info1, {}, expected5) + verify_model(Mul2(), input_info2, {}, expected6) + + # True div + class TrueDiv1(Module): + def forward(self, lhs, rhs): + return lhs / rhs + + @tvm.script.ir_module + class expected7: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class TrueDiv2(Module): + def forward(self, lhs): + return lhs / 1.0 + + @tvm.script.ir_module + class expected8: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(TrueDiv1(), input_info1, {}, expected7) + verify_model(TrueDiv2(), input_info2, {}, expected8) + + # Floor div + class FloorDiv1(Module): + def forward(self, lhs, rhs): + return lhs // rhs + + @tvm.script.ir_module + class expected9: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class FloorDiv2(Module): + def forward(self, lhs): + return lhs // 1.0 + + @tvm.script.ir_module + class 
expected10: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.floor_divide(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(FloorDiv1(), input_info1, {}, expected9) + verify_model(FloorDiv2(), input_info2, {}, expected10) + + # LT + class LT1(Module): + def forward(self, lhs, rhs): + return lhs < rhs + + @tvm.script.ir_module + class expected11: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, rhs_1) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + class LT2(Module): + def forward(self, lhs): + return lhs < 1.0 + + @tvm.script.ir_module + class expected12: + @R.function + def main( + lhs_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tensor((1, 3, 10, 10), dtype="bool"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(lhs_1, R.const(1.0)) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv + R.output(gv) + return gv + + verify_model(LT1(), input_info1, {}, expected11) + verify_model(LT2(), input_info2, {}, expected12) + + [email protected]_gpu +def test_size(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Size(Module): + def forward(self, input): + return input.size() + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return 
gv + + verify_model(Size(), input_info, {}, expected1) + + [email protected]_gpu +def test_unsqueeze(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Unsqueeze1(Module): + def forward(self, input): + return input.unsqueeze(1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) + gv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + class Unsqueeze2(Module): + def forward(self, input): + return input.unsqueeze(-1) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10, 1), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) + gv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Unsqueeze1(), input_info, {}, expected1) + verify_model(Unsqueeze2(), input_info, {}, expected2) + + [email protected]_gpu +def test_getattr(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class GetAttr1(Module): + def forward(self, input): + return input.shape + + @tvm.script.ir_module + class expected1: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]): + # block 0 + with R.dataflow(): + gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10]) + R.output(gv) + return gv + + verify_model(GetAttr1(), input_info, {}, expected1) + + [email protected]_gpu +def test_getitem(): + import torch + from torch.nn import Module + + 
torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Slice1(Module): + def forward(self, x): + return x[0, 1::2, :, :3] + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 1, 10, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 1, 10, 3), dtype="float32") = R.strided_slice( + x, + axes=[0, 1, 2, 3], + begin=[0, 1, 0, 0], + end=[1, T.int64(3), T.int64(10), 3], + strides=[1, 2, 1, 1], + ) + lv1: R.Tensor((1, 1, 10, 3), dtype="float32") = R.reshape(lv, (1, 1, 10, 3)) + gv: R.Tensor((1, 1, 10, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Slice1(), input_info, {}, expected1) + + [email protected]_gpu +def test_unary(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + # sin + class Sin(Module): + def forward(self, input): + return torch.sin(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sin(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sin(), input_info, {}, expected1) + + # cos + class Cos(Module): + def forward(self, input): + return torch.cos(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.cos(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Cos(), input_info, {}, expected2) + + # sqrt + class Sqrt(Module): + def 
forward(self, input): + return torch.sqrt(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sqrt(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sqrt(), input_info, {}, expected3) + + [email protected]_gpu +def test_gelu(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Gelu(Module): + def forward(self, input): + return torch.nn.functional.gelu(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.gelu(input_1) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Gelu(), input_info, {}, expected1) + + [email protected]_gpu +def test_clamp(): + import torch + from torch import fx + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Clamp(Module): + def forward(self, input): + return torch.clamp(input, min=0.1, max=0.5) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5) + gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Clamp(), input_info, {}, expected1) + + from tvm.relax.frontend.torch import from_fx + + with pytest.raises( + ValueError, match="TVM only 
supports constant max value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=0.5, max=None) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + with pytest.raises( + ValueError, match="TVM only supports constant min value for torch.clamp/clip" + ): + + class Clamp_Error(Module): + def forward(self, input): + return torch.clamp(input, min=input, max=input) + + gm = fx.symbolic_trace(Clamp_Error()) + from_fx(gm, input_info) + + [email protected]_gpu +def test_interpolate(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Interpolate(Module): + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(5, 5)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tensor((1, 3, 5, 5), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d( + input_1, + (5, 5), + roi=[0.000000, 0.000000, 0.000000, 0.000000], + layout="NCHW", + method="nearest_neighbor", + coordinate_transformation_mode="asymmetric", + rounding_method="round", + cubic_alpha=-0.5, + cubic_exclude=0, + extrapolation_value=0, + out_dtype="", + ) + gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Interpolate(), input_info, {}, expected1) + + [email protected]_gpu +def test_addmm(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [ + ([10, 10], "float32"), + ([10, 10], "float32"), + ([10, 10], "float32"), + ] + + class Addmm(Module): + def forward(self, x1, x2, x3): + return torch.addmm(x1, x2, x3) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x1: R.Tensor((10, 10), dtype="float32"), + x2: R.Tensor((10, 
10), dtype="float32"), + x3: R.Tensor((10, 10), dtype="float32"), + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") + lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) + gv: R.Tensor((10, 10), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(Addmm(), input_info, {}, expected1) + + [email protected]_gpu +def test_split(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 3, 10, 10], "float32")] + + class Split(Module): + def forward(self, input): + return torch.split(input, 1, dim=1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ): + # block 0 + with R.dataflow(): + lv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = R.split(input_1, indices_or_sections=3, axis=1) + gv: R.Tuple( + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + R.Tensor((1, 1, 10, 10), dtype="float32"), + ) = lv + R.output(gv) + return gv + + verify_model(Split(), input_info, {}, expected1) + + [email protected]_gpu +def test_tril(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([10, 10], "float32")] + + class Tril(Module): + def forward(self, input): + return torch.tril(input, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) + gv: 
R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Tril(), input_info, {}, expected1) + + [email protected]_gpu +def test_new_ones(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3], "float32")] + + class NewOnes(Module): + def forward(self, x): + return x.new_ones(1, 2, 3) + + @tvm.script.ir_module + class expected1: + @R.function + def main(x: R.Tensor((1, 2, 3), dtype="float32")) -> R.Tensor((1, 2, 3), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( + (1, 2, 3), R.const(1, "float32"), dtype="float32" + ) + gv: R.Tensor((1, 2, 3), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(NewOnes(), input_info, {}, expected1) + + [email protected]_gpu +def test_expand(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Expand(Module): + def forward(self, x): + return x.expand(4, 2, 3, 4) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((4, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) + gv: R.Tensor((4, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Expand(), input_info, {}, expected1) + + [email protected]_gpu +def test_reduce(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # sum + class Sum(Module): + def forward(self, x): + return torch.sum(x, (2, 1)) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: 
R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) + gv: R.Tensor((1, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Sum(), input_info, {}, expected1) + + [email protected]_gpu +def test_to(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + # float + class ToFloat(Module): + def forward(self, x): + return x.float() + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") + gv: R.Tensor((1, 2, 3, 4), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(ToFloat(), input_info, {}, expected1) + + # half + class ToHalf(Module): + def forward(self, x): + return x.half() + + @tvm.script.ir_module + class expected2: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 2, 3, 4), dtype="float16"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") + gv: R.Tensor((1, 2, 3, 4), dtype="float16") = lv + R.output(gv) + return gv + + verify_model(ToHalf(), input_info, {}, expected2) + + [email protected]_gpu +def test_permute(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + torch.random.manual_seed(0) + + input_info = [([1, 2, 3, 4], "float32")] + + class Permute(Module): + def forward(self, x): + return x.permute(0, 3, 2, 1) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tensor((1, 4, 3, 2), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) + gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv + R.output(gv) + 
            return gv

    verify_model(Permute(), input_info, {}, expected1)


@tvm.testing.requires_gpu
def test_reshape():
    """Translate ``Tensor.reshape`` and check it lowers to ``R.reshape``."""
    import torch
    from torch.nn import Module

    # Tests are pure graph-translation checks: no gradients, fixed seed.
    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    # Exact IRModule the FX translator is expected to produce.
    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tensor((2, 12), dtype="float32") = lv
                R.output(gv)
            return gv

    # verify_model (defined earlier in this file) presumably traces the module
    # via torch.fx and structurally compares the result against expected1.
    verify_model(Reshape(), input_info, {}, expected1)


@tvm.testing.requires_gpu
def test_transpose():
    """Translate ``Tensor.transpose(dim0, dim1)``.

    Torch's two-axis transpose has no direct Relax equivalent, so it is
    expected to lower to a full ``R.permute_dims`` with the two axes swapped.
    """
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(
            x: R.Tensor((1, 2, 3, 4), dtype="float32")
        ) -> R.Tensor((1, 4, 3, 2), dtype="float32"):
            # block 0
            with R.dataflow():
                # transpose(1, 3) == permute over all axes with 1 and 3 swapped
                lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1])
                gv: R.Tensor((1, 4, 3, 2), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(Transpose(), input_info, {}, expected1)


@tvm.testing.requires_gpu
def test_view():
    """Translate ``Tensor.view``; expected to lower to ``R.reshape`` exactly
    like ``Tensor.reshape`` (see test_reshape in this file)."""
    import torch
    from torch.nn import Module

    torch.set_grad_enabled(False)
    torch.random.manual_seed(0)

    input_info = [([1, 2, 3, 4], "float32")]

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    @tvm.script.ir_module
    class expected1:
        @R.function
        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((2, 12), dtype="float32"):
            # block 0
            with R.dataflow():
                lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12))
                gv: R.Tensor((2, 12), dtype="float32") = lv
                R.output(gv)
            return gv

    verify_model(View(), input_info, {}, expected1)


if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py
new file mode 100644
index 0000000000..6d6704ae97
--- /dev/null
+++ b/tests/python/relax/test_pipeline.py
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R


def test_pipeline_compile():
    """End-to-end smoke test for the default Relax compilation pipeline.

    Runs a trivial element-wise add module through ``relax.get_pipeline()``,
    builds it for an LLVM host target, executes it on the Relax VM, and
    checks the numeric result against NumPy.
    """
    # Default (no-argument) pipeline introduced by this commit.
    pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")):
            lv0 = R.add(x, y)
            return lv0

    mod = Mod
    # The pipeline lowers/optimizes the IRModule in place of a manual pass list.
    mod = pipeline(mod)
    target = tvm.target.Target("llvm", host="llvm")

    ex = relax.vm.build(mod, target)
    x_np = np.random.rand(3, 4).astype(np.float32)
    y_np = np.random.rand(3, 4).astype(np.float32)
    x = tvm.nd.array(x_np)
    y = tvm.nd.array(y_np)

    vm = relax.VirtualMachine(ex, tvm.cpu())
    z = vm["main"](x, y)
    tvm.testing.assert_allclose(z.numpy(), x_np + y_np, rtol=1e-7, atol=1e-7)
