This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f62d3277abe093c3923099579bec0b2e2c0a88ee Author: Tristan Konolige <[email protected]> AuthorDate: Thu May 11 16:23:18 2023 +0000 Add einsum, gelu, pad support to relax --- include/tvm/relax/attrs/nn.h | 32 +++ python/tvm/relax/frontend/torch/fx_translator.py | 101 ++++++++- python/tvm/relax/op/nn/nn.py | 28 ++- python/tvm/relax/transform/legalize_ops/nn.py | 6 + src/relax/op/nn/pad.cc | 250 +++++++++++++++++++++++ src/relax/op/op_common.cc | 2 +- src/tir/op/op.cc | 2 +- 7 files changed, 415 insertions(+), 6 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index c1ca468fc9..8c5666d62b 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -320,6 +320,38 @@ struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> { } }; // struct AttentionAttrs +/*! \brief Attributes used for the padding operator */ +struct PadAttrs : public tvm::AttrsNode<PadAttrs> { + Array<Array<Integer>> pad_width; + tvm::String pad_mode; + + TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") { + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + TVM_ATTR_FIELD(pad_mode) + .set_default("constant") + .describe( + "Padding type to use. \"constant\" pads with constant_value, " + "\"edge\" pads using the edge values of the input array, " + "\"reflect\" pads by reflecting values with respect to the edges."); + } +}; + +/*! \brief Attributes used for the MirrorPadding operator */ +struct MirrorPadAttrs : public tvm::AttrsNode<MirrorPadAttrs> { + std::string mode; + Array<Array<PrimExpr>> pad_width; + + TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") { + TVM_ATTR_FIELD(mode) + .set_default("SYMMETRIC") + .describe("Specifies how mirroring should be performed."); + TVM_ATTR_FIELD(pad_width).describe( + "Number of values padded to the edges of each axis, " + "in the format of ((before_1, after_1), ..., (before_N, after_N))"); + } +}; } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a29070a325..7166b312a0 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -194,6 +194,10 @@ class TorchFXImporter: lhs, rhs = self.retrieve_args(node) if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var): return self._call_binary_op(relax.op.multiply, lhs, rhs) + if isinstance(rhs, (int, float)): + rhs = relax.const(rhs) + if isinstance(lhs, (int, float)): + lhs = relax.const(lhs) return lhs * rhs def _pow(self, node: fx.node.Node) -> relax.Expr: @@ -417,6 +421,18 @@ class TorchFXImporter: def _half(self, node: fx.node.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + def _long(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "int64")) + + def _int(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "int32")) + + def _short(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "int16")) + + def _int8(self, node: fx.node.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "int8")) + def _type(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) @@ -576,6 +592,10 @@ class TorchFXImporter: return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) + def _einsum(self, node: fx.node.Node) -> relax.Var: + operands = [self.env[x] for x in node.args[1:]] + return self.block_builder.emit(relax.op.einsum(operands, node.args[0])) + def _index_select(self, node: fx.node.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] @@ -1002,6 +1022,7 @@ class TorchFXImporter: ) ) +<<<<<<< HEAD def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var: assert len(node.args) <= 4, "Dropout, and causal masking are not supported." transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) @@ -1019,6 +1040,62 @@ class TorchFXImporter: return self.block_builder.emit(attn) +||||||| parent of 97e234777 (Add einsum, gelu, pad support to relax) +======= + def _pad_common(self, mode, pad_value, inputs): + data = self.env[inputs[0]] + pad_list = inputs[1] + + # initialize paddings based on input len + pad_len = len(data.struct_info.shape) * 2 + paddings = [0] * pad_len + + if len(pad_list) >= 2: + paddings[-1] = pad_list[1] + paddings[-2] = pad_list[0] + if len(pad_list) >= 4: + paddings[-3] = pad_list[3] + paddings[-4] = pad_list[2] + if len(pad_list) >= 6: + paddings[-5] = pad_list[5] + paddings[-6] = pad_list[4] + + # group into tuple of 2 ints + paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)] + + if mode == "constant": + return self.block_builder.emit(relax.op.nn.pad(data, paddings, pad_value=relax.const(pad_value), pad_mode=mode)) + else: + return self.block_builder.emit(relax.op.nn.pad(data, paddings, pad_mode=mode)) + + def _pad(self, node: fx.node.Node) -> relax.Expr: + if len(node.args) > 2 and node.args[2] is not None: + mode = node.args[2] + else: + mode = 'constant' + + if len(node.args) == 4 and node.args[3] is not None: + pad_value = node.args[3] + else: + pad_value = 0 + return self._pad_common(mode, pad_value, node.args) + + def _unbind(self, node: fx.node.Node) -> relax.Expr: + data = self.env[node.args[0]] + shape = data.struct_info.shape + if len(node.args) == 2 and node.args[1] is not None: + axis = node.args[1] + else: + axis = 0 + + selections = shape[axis] + res_split = relax.op.split(data, selections, axis) + ret = [] + for i in range(selections.value): + ret.append(relax.op.squeeze(res_split[i], axis=[axis])) + return self.block_builder.emit(relax.expr.Tuple(ret)) + + ########## Others ########## def _size(self, node: fx.node.Node) -> relax.Expr: @@ -1040,6 +1117,8 @@ class TorchFXImporter: return getattr(self.env[node.args[0]], node.args[1]) def _getitem(self, node: fx.node.Node) -> relax.Var: + from torch import fx + x = self.env[node.args[0]] if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)): return x[node.args[1]] @@ -1095,10 +1174,15 @@ class TorchFXImporter: sliced_shape.insert(i, 1) return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) elif isinstance(x, relax.Constant): - dtype = x.struct_info.dtype - return relax.const(x.data.numpy()[node.args[1]], dtype) + if isinstance(node.args[1], fx.node.Node): + idx = self.env[node.args[1]] + return self.block_builder.emit(relax.op.take(x,idx,axis=0)) + else: + dtype = x.struct_info.dtype + return relax.const(x.data.numpy()[node.args[1]], dtype) else: - assert False + import IPython;IPython.embed() + raise ValueError(f"Unsupported type {type(x)} for _getitem, should be list, tuple, ShapeExpr, Tuple, Var, or Constant") def create_convert_map(self): from torch import nn @@ -1118,6 +1202,7 @@ class TorchFXImporter: relax.op.clip(self.env[node.args[0]], 0, 6) ), nn.SiLU: lambda node: self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])), + nn.GELU: lambda node: self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])), nn.Flatten: self._flatten, nn.BatchNorm2d: self._batch_norm_2d, nn.LayerNorm: self._layer_norm, @@ -1154,6 +1239,13 @@ class TorchFXImporter: "sum": self._sum, "float": self._float, "half": self._half, + "long": self._long, + "int64": self._long, + "int": self._int, + "int32": self._int, + "short": self._short, + "int16": self._short, + "int8": self._int8, "type": self._type, "astype": self._type, "matmul": self._matmul, @@ -1203,6 +1295,9 @@ class TorchFXImporter: "max": self._max, "cross_entropy": self._cross_entropy, "scaled_dot_product_attention": self._scaled_dot_product_attention, + "pad": self._pad, + "unbind": self._unbind, + "einsum": self._einsum, } def from_fx( diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index fb5e0736ff..5c58cc5966 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Relax Neural Network (NN) operators""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Sequence from tvm import DataType from tvm.tir import FloatImm @@ -1022,3 +1022,29 @@ def attention( (batch_size, seq_len, num_head, head_dim_v). """ return _ffi_api.attention(query, key, value, bias, scale) # type: ignore + +def pad(data: Expr, pad_width: Sequence[Sequence[int]], pad_value: Union[Expr, float]=0., pad_mode: str="constant"): + r"""Padding + + This operator takes in a tensor and pads each axis by the specified + widths using the specified value. + + Parameters + ---------- + data: relax.Expr + The input data to the operator + pad_width: tuple of <tuple of <int>>, required + Number of values padded to the edges of each axis, in the format + of ((before_1, after_1), ..., (before_N, after_N)) + pad_value: float, or tvm.relay.Expr, optional, default=0 + The value used for padding + pad_mode: 'constant', 'edge', 'reflect' + 'constant' pads with constant_value pad_value + 'edge' pads using the edge values of the input array + 'reflect' pads by reflecting values with respect to the edge + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.pad(data, pad_width, pad_value, pad_mode) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 9c98682e32..b362660561 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -397,3 +397,9 @@ def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr: reduction=call.attrs.reduction, ignore_index=call.attrs.ignore_index, ) + +@register_legalize("relax.nn.pad") +def _pad(bb: BlockBuilder, call: Call) -> Expr: + assert call.attrs.pad_mode == "constant" + # return bb.call_te(topi.nn.pad, call.args[0], [x[0] for x in call.attrs.pad_width], [x[1] for x in call.attrs.pad_width], call.args[1]) + return bb.call_te(topi.nn.pad, call.args[0], [x[0] for x in call.attrs.pad_width], [x[1] for x in call.attrs.pad_width], tir.const(call.args[1].data.numpy()[()])) diff --git a/src/relax/op/nn/pad.cc b/src/relax/op/nn/pad.cc new file mode 100644 index 0000000000..224ff7cb51 --- /dev/null +++ b/src/relax/op/nn/pad.cc @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file pad.cc + * \brief Implementation of operator pad + */ +#include <tvm/relax/attrs/nn.h> +#include <tvm/tir/data_layout.h> +#include <tvm/tir/op.h> +#include <tvm/topi/elemwise.h> +#include <tvm/topi/nn.h> + +#include <vector> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +// relax.nn.pad +TVM_REGISTER_NODE_TYPE(PadAttrs); + +InferLayoutOutput PadInferLayout(const Call& call, + const Map<String, Array<String>>& desired_layouts, + const VarLayoutMap& var_layout_map) { + + LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, call->args[0]); + return InferLayoutOutput({exisiting_layout}, + {exisiting_layout}, + Attrs(call->attrs)); + // TODO: handle layout changes for pad + // const auto* attrs_ptr = call->attrs.as<PadAttrs>(); + // CHECK(attrs_ptr); + // ObjectPtr<PadAttrs> params = make_object<PadAttrs>(*attrs_ptr); + // + // Layout ret_data; + // // If new_in_layouts are defined, this code tries to modify the layout. + // bool is_layout_modified = new_in_layouts.defined(); + // if (new_in_layouts.defined()) { + // // Create a map of axis to param_width. For the new layout, a new param_width is generated using + // // the map. The new layout is rejected, if the padding is happening along the axis which was + // // split. + // + // // 1) Create a map from axis to param_width using old layout. + // std::map<std::string, tvm::Array<Integer>> axis_pad_width; + // int index_counter = 0; + // ICHECK_EQ(new_in_layouts.size(), 2); + // ICHECK_EQ(old_in_layouts.size(), 2); + // for (auto iter_var : old_in_layouts[0]->axes) { + // const auto& old_layout_axis = LayoutAxis::Get(iter_var); + // axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]); + // index_counter++; + // } + // + // // 2) Create new pad width by walking over the new layout and using the map. + // tvm::Array<tvm::Array<Integer>> new_pad_width; + // for (auto iter_var : new_in_layouts[0]->axes) { + // const auto& new_layout_axis = LayoutAxis::Get(iter_var); + // auto axis_name = new_layout_axis.name(); + // if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) { + // // This is primal axis. So, directly use the original pad_width. + // new_pad_width.push_back(axis_pad_width.at(axis_name)); + // } else { + // // This is the axis that got split. So, check that pad_width was [0, 0] originally. + // const auto& dual_axis = new_layout_axis.ToPrimal(); + // auto dual_axis_name = dual_axis.name(); + // ICHECK(axis_pad_width.count(dual_axis_name)) + // << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name(); + // new_pad_width.push_back(axis_pad_width.at(dual_axis_name)); + // + // // If any pad_width element is not zero, do not change the layout. + // for (auto width : axis_pad_width.at(dual_axis_name)) { + // if (auto* width_imm = width.as<IntImmNode>()) { + // if (width_imm->value != 0) { + // is_layout_modified = false; + // } + // } else { + // is_layout_modified = false; + // } + // } + // } + // } + // + // // If the above conditions satisfied, we can set the newly created pad_width and use the new + // // layout. + // if (is_layout_modified) { + // ret_data = new_in_layouts[0]; + // params->pad_width = new_pad_width; + // } + // } + // + // if (!is_layout_modified) { + // if (old_in_layouts.defined()) { + // ICHECK_EQ(old_in_layouts.size(), 2); + // ret_data = old_in_layouts[0]; + // } else { + // ret_data = Layout::Undef(); + // } + // } + // + // // The pad value is always a scalar + // Layout ret_pad_value = Layout("1"); + // return InferLayoutOutput({ret_data, ret_pad_value}, {ret_data}, Attrs(params)); +} + +StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { + auto infos = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = infos[0]; + const auto* attrs = call->attrs.as<PadAttrs>(); + + const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>(); + ICHECK_EQ(data_shape->values.size(), attrs->pad_width.size()) << "Data shape and padding should be the same size"; + + Array<PrimExpr> padded_shape; + for(size_t i = 0; i < data_shape->values.size(); i++) { + padded_shape.push_back(data_shape->values[i] + attrs->pad_width[i][0] + attrs->pad_width[i][1]); + } + return TensorStructInfo(ShapeExpr(padded_shape), data_sinfo->dtype); +} + +Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, + const Type& out_type) { + const auto* param = attrs.as<PadAttrs>(); + ICHECK(param != nullptr); + + auto pad_width = param->pad_width; + ICHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << "Illegal pad_width"; + Array<PrimExpr> pad_before; + for (size_t i = 0; i < pad_width.size(); ++i) { + pad_before.push_back(pad_width[i][0]); + } + Array<PrimExpr> pad_after; + for (size_t i = 0; i < pad_width.size(); ++i) { + pad_after.push_back(pad_width[i][1]); + } + te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype); + const PrimExpr& pad_value = cast_pad_value(Array<PrimExpr>(inputs[1]->shape.size(), 0)); + return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, pad_value, "T_pad", + topi::kElementWise, param->pad_mode)}; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, String pad_mode) { + auto attrs = make_object<PadAttrs>(); + attrs->pad_width = std::move(pad_width); + attrs->pad_mode = std::move(pad_mode); + static const Op& op = Op::Get("relax.nn.pad"); + return Call(op, {data, pad_value}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(MakePad); + +TVM_REGISTER_OP("relax.nn.pad") + .describe(R"code(Pad for n-D tensor. + +)code" TVM_ADD_FILELINE) + .set_attrs_type<PadAttrs>() + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("pad_val", "Tensor", "The value to fill the padded area with") + .set_support_level(2) + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad) + .set_attr<FRelaxInferLayout>("FRelaxInferLayout", PadInferLayout); + // .set_attr<TOpPattern>("TOpPattern", kInjective) + // .set_attr<FTVMCompute>("FTVMCompute", PadCompute); + +// relax.nn.mirror_pad +TVM_REGISTER_NODE_TYPE(MirrorPadAttrs); + +bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as<TensorTypeNode>(); + if (data == nullptr) return false; + + const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>(); + ICHECK(param != nullptr); + + // check that pad widths match lengths + ICHECK(data->shape.size() == param->pad_width.size()) + << "There should be as many pad width pairs as shape dimensions " + << "but the shape has " << data->shape.size() << " dimensions " + << "and there are " << param->pad_width.size() << " pad width pairs."; + + // each pad width element should be a pair of positive integers + std::vector<PrimExpr> oshape; + for (size_t i = 0; i < param->pad_width.size(); i++) { + ICHECK(param->pad_width[i].size() == 2) + << "Each pad width element should be a pair but at index " << i << " there are " + << param->pad_width[i].size() << " elements."; + + auto width1 = tir::as_const_int(param->pad_width[i][0]); + auto width2 = tir::as_const_int(param->pad_width[i][1]); + ICHECK(width1 != nullptr); + ICHECK(width2 != nullptr); + + ICHECK(*width1 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width1 << "."; + ICHECK(*width2 >= 0) << "Param width elements should be positive but first pad width at " + << "index " << i << " is " << *width2 << "."; + + auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2); + oshape.push_back(data->shape[i] + padding); + } + + reporter->Assign(types[1], TensorType(Array<PrimExpr>(oshape), data->dtype)); + return true; +} + +// Handler to create a call to the padding op used by front-end FFI +Expr MakeMirrorPad(Expr data, Array<Array<PrimExpr>> pad_width, String mode) { + auto attrs = make_object<MirrorPadAttrs>(); + attrs->mode = mode; + attrs->pad_width = std::move(pad_width); + static const Op& op = Op::Get("nn.mirror_pad"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad); + +TVM_REGISTER_OP("nn.mirror_pad") + .describe(R"code(MirrorPad for n-D tensor. + +)code" TVM_ADD_FILELINE) + .set_attrs_type<MirrorPadAttrs>() + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_support_level(2); + // .add_type_rel("MirrorPad", MirrorPadRel) + // .set_attr<TOpPattern>("TOpPattern", kInjective); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 7428fd66fc..fa550f415d 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -29,7 +29,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call& call, const BlockBu int n_input = op->arguments.size(); if (static_cast<int>(call->args.size()) != n_input) { ctx->ReportFatal(Diagnostic::Error(call) - << op << " op should have " << n_input << " arguments"); + << op << " op should have " << n_input << " arguments" << " but it got " << call->args.size()); } Array<TensorStructInfo> input_tensor_sinfo; input_tensor_sinfo.reserve(n_input); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 4439a9c3d7..4b3a0f5a7a 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -937,7 +937,7 @@ TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { } else if (args[0].type_code() == kDLFloat) { *ret = tir::make_const(args[1], args[0].operator double(), args[2]); } else { - LOG(FATAL) << "only accept int or float"; // FIXME + LOG(FATAL) << "Constant must be int or float but got " << tvm::runtime::ArgTypeCode2Str(args[0].type_code()); // FIXME } });
