This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit a54064659bd6e1239ec612864130428f78503bc3 Author: Tristan Konolige <[email protected]> AuthorDate: Fri May 12 21:48:50 2023 +0000 working --- include/tvm/relax/attrs/index.h | 8 + python/tvm/relax/expr.py | 2 +- python/tvm/relax/frontend/torch/fx_translator.py | 204 +++++++++++++++++++++-- src/relax/op/tensor/index.cc | 1 - 4 files changed, 203 insertions(+), 12 deletions(-) diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h index c95395a803..e31a334bf3 100644 --- a/include/tvm/relax/attrs/index.h +++ b/include/tvm/relax/attrs/index.h @@ -56,6 +56,14 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { } }; // struct StridedSliceAttrs +/*! \brief Attributes used in strided_slice operator */ +struct DataDependentStridedSliceAttrs : public tvm::AttrsNode<DataDependentStridedSliceAttrs> { + Array<Integer> axes; + TVM_DECLARE_ATTRS(DataDependentStridedSliceAttrs, "relax.attrs.DataDependentStridedSliceAttrs") { + TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied."); + } +}; + } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index fdf98c179b..d3550741f7 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -652,7 +652,7 @@ def const( value = _nd.array(value) if not isinstance(value, _nd.NDArray): - raise ValueError("value has to be scalar or NDArray") + raise ValueError(f"value has to be scalar or NDArray but it is {type(value)}") return Constant(value) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7166b312a0..2d94b246de 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -23,7 +23,105 @@ from functools import reduce import tvm from tvm import relax - +import numpy as np + +def _pytorch_result_type(dtypes, non_tensor_inputs): + """This promotes TVM dtypes like PyTorch would""" + import torch + + dtype_map = { + "float64": torch.float64, + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int64": torch.int64, + "int32": torch.int32, + "int16": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool, + } + if len(dtypes) > 0: + result_type = dtypes[0] + for dt in dtypes[1:]: + if dt != result_type: # we don't want to work with same types as we + # don't do quantized here (which cannot be promoted?) + result_type = _convert_data_type( + str( + torch.result_type( + torch.zeros((), dtype=dtype_map[result_type]), + torch.zeros((), dtype=dtype_map[dt]), + ) + ) + ) + else: + result_type = "bool" # this is the smallest type... + for inp in non_tensor_inputs: + result_type = _convert_data_type( + str(torch.result_type(torch.zeros((), dtype=dtype_map[result_type]), inp)) + ) + return result_type + + +# Helper functions for operator implementation +def _convert_dtype_value(val): + """converts a PyTorch the PyTorch numeric type id to a torch scalar type.""" + convert_torch_dtype_map = { + 11: "torch.bool", + 7: "torch.float64", + 6: "torch.float32", + 5: "torch.float16", + 4: "torch.int64", + 3: "torch.int32", + 2: "torch.int16", + 1: "torch.int8", + 0: "torch.uint8", + None: "torch.int64", + } # Default is torch.int64 + if val in convert_torch_dtype_map: + return _convert_data_type(convert_torch_dtype_map[val]) + else: + msg = "Torch data type value %d is not handled yet." % (val) + raise NotImplementedError(msg) + + +def _convert_data_type(input_type, default_dtype=None): + """converts the PyTorch scalar type input_type to a TVM dtype. + optionally, default_dtype can be a TVM dtype that is used + if input_type is None (but not when it is unknown)""" + if input_type is None and default_dtype is not None: + return default_dtype + + input_type = input_type.lower() + if input_type in ["double", "float64", "torch.float64"]: + return "float64" + elif input_type in ["float", "float32", "torch.float32"]: + return "float32" + elif input_type in ["half", "float16", "torch.float16"]: + return "float16" + elif input_type in ["long", "int64", "torch.int64"]: + return "int64" + elif input_type in ["int", "int32", "torch.int32"]: + return "int32" + elif input_type in ["short", "int16", "torch.int16"]: + return "int16" + elif input_type in ["char", "int8", "torch.int8"]: + return "int8" + elif input_type in ["byte", "uint8", "torch.uint8"]: + return "uint8" + elif input_type in ["quint8", "torch.quint8"]: + return "quint8" + elif input_type in ["qint8", "torch.qint8"]: + return "qint8" + elif input_type in ["qint32", "torch.qint32"]: + return "qint32" + elif input_type in ["bool", "torch.bool"]: + return "bool" + elif input_type in ["str"]: + return "str" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + return "float32" # Never reached class TorchFXImporter: """An importer from PyTorch FX to Relax.""" @@ -86,6 +184,23 @@ class TorchFXImporter: dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) return relax.const(tensor.data.numpy(), dtype) + @staticmethod + def _promote_types(inputs): + """This promotes TVM inputs with TVM dtypes passed like PyTorch would""" + dtypes = [inp.struct_info.dtype for inp in inputs] + tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)] + non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)] + result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs) + results = [] + for inp, dt in zip(inputs, dtypes): + if np.isscalar(inp): + results.append(relax.const(inp, dtype=result_type)) + elif dt == result_type: + results.append(inp) + else: + results.append(relax.op.astype(inp, result_type)) + return results + @staticmethod def shape_of(tensor): """Get the shape of a tensor.""" @@ -119,7 +234,8 @@ class TorchFXImporter: @staticmethod def _promote_binary_op_args(lhs, rhs): if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr): - return lhs, rhs + lhs_, rhs_ = TorchFXImporter._promote_types((lhs, rhs)) + return lhs_, rhs_ elif isinstance(lhs, relax.Expr): assert isinstance(lhs.struct_info, relax.TensorStructInfo) return lhs, relax.const(rhs, lhs.struct_info.dtype) @@ -151,6 +267,8 @@ class TorchFXImporter: arg = self.env[node.args[0]] if isinstance(arg, (int, float)): arg = relax.const(arg, "float32") + if isinstance(arg, (tvm.tir.FloatImm, tvm.tir.IntImm)): + arg = relax.const(arg.value, "float32") return self.block_builder.emit(relax.op.sqrt(arg)) def _rsqrt(self, node: fx.node.Node) -> relax.Expr: @@ -180,9 +298,8 @@ class TorchFXImporter: return lhs + rhs def _max(self, node: fx.node.Node) -> relax.Expr: - lhs, rhs = self.retrieve_args(node) - if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): - return self._call_binary_op(relax.op.maximum, lhs, rhs) + lhs, = self.retrieve_args(node) + return self.block_builder.emit(relax.op.max(lhs)) def _floordiv(self, node: fx.node.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) @@ -190,6 +307,10 @@ class TorchFXImporter: return self._call_binary_op(relax.op.floor_divide, lhs, rhs) return lhs // rhs + def _floor(self, node: fx.node.Node) -> relax.Expr: + lhs, = self.retrieve_args(node) + return self.block_builder.emit(relax.op.floor(lhs)) + def _mul(self, node: fx.node.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): @@ -248,6 +369,10 @@ class TorchFXImporter: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.equal, lhs, rhs) + def _ne(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.not_equal, lhs, rhs) + ########## Creation ########## def _arange(self, node: fx.node.Node) -> relax.Var: @@ -501,7 +626,13 @@ class TorchFXImporter: def _expand(self, node: fx.node.Node) -> relax.Var: args = self.retrieve_args(node) - return self.block_builder.emit(relax.op.broadcast_to(args[0], args[1:])) + # -1 indicates dimension remains the same + fixed_dims = [args[0].struct_info.shape[i] if d == -1 else d for i, d in enumerate(args[1:])] + return self.block_builder.emit(relax.op.broadcast_to(args[0], fixed_dims)) + + def _expand_as(self, node: fx.node.Node) -> relax.Var: + args = self.retrieve_args(node) + return self.block_builder.emit(relax.op.broadcast_to(args[0], self.shape_of(args[1]))) def _flatten(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -574,6 +705,19 @@ class TorchFXImporter: dim = None return self.block_builder.emit(relax.op.squeeze(x, dim)) + def _repeat_interleave(self, node: fx.node.Node) -> relax.Var: + data, repeats = self.retrieve_args(node) + if not "dim" in node.kwargs: + data = relax.op.reshape(args[0], [-1]) + axis = 0 + else: + axis = node.kwargs["dim"] + return self.block_builder.emit(relax.op.repeat(data, repeats=repeats.value, axis=axis)) + + def _stack(self, node: fx.node.Node) -> relax.Var: + args = [self.env[x] for x in node.args[0]] + return self.block_builder.emit(relax.op.concat(args, node.kwargs["dim"])) + def _cumsum(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] @@ -711,6 +855,34 @@ class TorchFXImporter: return self.block_builder.emit(relax.op.add(conv2d, bias)) + def _convtranpose2d(self, node: fx.node.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + + conv2d = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + data_layout="NCHW", + kernel_layout="IOHW", + out_dtype="float32", + ) + ) + + if module.bias is None: + return conv2d + + bias = self.params[module.bias] + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + + return self.block_builder.emit(relax.op.add(conv2d, bias)) + def _max_pool2d(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1135,10 +1307,11 @@ class TorchFXImporter: i = 0 shape = self.shape_of(x) non_ellipsis_cnt = 0 - for index in node.args[1]: + idxs = [node.args[1]] if isinstance(node.args[1], int) else node.args[1] + for index in idxs: if isinstance(index, (int, slice)): non_ellipsis_cnt += 1 - for index in node.args[1]: + for index in idxs: if isinstance(index, int): begin.append(index) end.append(index + 1) @@ -1147,7 +1320,12 @@ class TorchFXImporter: i = i + 1 elif isinstance(index, slice): begin.append(0 if index.start is None else index.start) - end.append(shape[i] if index.stop is None else index.stop) + if index.stop is None: + end.append(shape[i]) + elif isinstance(index.stop, fx.node.Node): + end.append(self.env[index.stop]) + else: + end.append(index.stop) stride.append(1 if index.step is None else index.step) axes.append(i) i = i + 1 @@ -1168,6 +1346,7 @@ class TorchFXImporter: stride.append(1) axes.append(i) i += 1 + print([type(x) for x in axes], [type(x) for x in begin], [type(x) for x in end], [type(x) for x in stride]) sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) sliced_shape = list(self.shape_of(sliced)) for i in expand_dim: @@ -1181,7 +1360,6 @@ class TorchFXImporter: dtype = x.struct_info.dtype return relax.const(x.data.numpy()[node.args[1]], dtype) else: - import IPython;IPython.embed() raise ValueError(f"Unsupported type {type(x)} for _getitem, should be list, tuple, ShapeExpr, Tuple, Var, or Constant") def create_convert_map(self): @@ -1193,6 +1371,7 @@ class TorchFXImporter: nn.Linear: self._linear, nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, + nn.ConvTranspose2d: self._convtranpose2d, nn.MaxPool2d: self._max_pool2d, nn.AvgPool2d: self._avg_pool2d, nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), @@ -1218,6 +1397,7 @@ class TorchFXImporter: "iadd": self._add, "add": self._add, "floordiv": self._floordiv, + "floor": self._floor, "mul": self._mul, "sub": self._sub, "pow": self._pow, @@ -1226,6 +1406,7 @@ class TorchFXImporter: "round": self._round, "lt": self._lt, "eq": self._eq, + "ne": self._ne, "truediv": self._truediv, "fill_": self._inplace_fill, "new_ones": self._new_ones, @@ -1265,6 +1446,8 @@ class TorchFXImporter: "unsqueeze": lambda node: self.block_builder.emit( relax.op.expand_dims(self.env[node.args[0]], node.args[1]) ), + "repeat_interleave": self._repeat_interleave, + "stack": self._stack, "view": self._reshape, "argmax": self._argmax_argmin(relax.op.argmax), "argmin": self._argmax_argmin(relax.op.argmin), @@ -1298,6 +1481,7 @@ class TorchFXImporter: "pad": self._pad, "unbind": self._unbind, "einsum": self._einsum, + "expand_as": self._expand_as, } def from_fx( diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index c3d38db4e1..d3bb34d21a 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -312,6 +312,5 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice") .add_argument("end", "Tensor", "Indices indicating end of the slice.") .add_argument("strides", "Tensor", "The stride values.") .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDynStridedSlice); - } // namespace relax } // namespace tvm
