This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f62d3277abe093c3923099579bec0b2e2c0a88ee
Author: Tristan Konolige <[email protected]>
AuthorDate: Thu May 11 16:23:18 2023 +0000

    Add einsum, gelu, pad support to relax
---
 include/tvm/relax/attrs/nn.h                     |  32 +++
 python/tvm/relax/frontend/torch/fx_translator.py | 101 ++++++++-
 python/tvm/relax/op/nn/nn.py                     |  28 ++-
 python/tvm/relax/transform/legalize_ops/nn.py    |   6 +
 src/relax/op/nn/pad.cc                           | 250 +++++++++++++++++++++++
 src/relax/op/op_common.cc                        |   2 +-
 src/tir/op/op.cc                                 |   2 +-
 7 files changed, 415 insertions(+), 6 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index c1ca468fc9..8c5666d62b 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -320,6 +320,38 @@ struct AttentionAttrs : public 
tvm::AttrsNode<AttentionAttrs> {
   }
 };  // struct AttentionAttrs
 
+/*! \brief Attributes used for the padding operator */
+struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
+  Array<Array<Integer>> pad_width;
+  tvm::String pad_mode;
+
+  TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
+    TVM_ATTR_FIELD(pad_width).describe(
+        "Number of values padded to the edges of each axis, "
+        "in the format of ((before_1, after_1), ..., (before_N, after_N))");
+    TVM_ATTR_FIELD(pad_mode)
+        .set_default("constant")
+        .describe(
+            "Padding type to use. \"constant\" pads with constant_value, "
+            "\"edge\" pads using the edge values of the input array, "
+            "\"reflect\" pads by reflecting values with respect to the 
edges.");
+  }
+};
+
+/*! \brief Attributes used for the MirrorPadding operator */
+struct MirrorPadAttrs : public tvm::AttrsNode<MirrorPadAttrs> {
+  std::string mode;
+  Array<Array<PrimExpr>> pad_width;
+
+  TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") {
+    TVM_ATTR_FIELD(mode)
+        .set_default("SYMMETRIC")
+        .describe("Specifies how mirroring should be performed.");
+    TVM_ATTR_FIELD(pad_width).describe(
+        "Number of values padded to the edges of each axis, "
+        "in the format of ((before_1, after_1), ..., (before_N, after_N))");
+  }
+};
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index a29070a325..7166b312a0 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -194,6 +194,10 @@ class TorchFXImporter:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
             return self._call_binary_op(relax.op.multiply, lhs, rhs)
+        if isinstance(rhs, (int, float)):
+            rhs = relax.const(rhs)
+        if isinstance(lhs, (int, float)):
+            lhs = relax.const(lhs)
         return lhs * rhs
 
     def _pow(self, node: fx.node.Node) -> relax.Expr:
@@ -417,6 +421,18 @@ class TorchFXImporter:
     def _half(self, node: fx.node.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"float16"))
 
+    def _long(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"int64"))
+
+    def _int(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"int32"))
+
+    def _short(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"int16"))
+
+    def _int8(self, node: fx.node.Node) -> relax.Var:
+        return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], 
"int8"))
+
     def _type(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dtype = TorchFXImporter._convert_data_type(node.args[1], self.env)
@@ -576,6 +592,10 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
 
+    def _einsum(self, node: fx.node.Node) -> relax.Var:
+        operands = [self.env[x] for x in node.args[1:]]
+        return self.block_builder.emit(relax.op.einsum(operands, node.args[0]))
+
     def _index_select(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -1002,6 +1022,7 @@ class TorchFXImporter:
             )
         )
 
+<<<<<<< HEAD
     def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
         assert len(node.args) <= 4, "Dropout, and causal masking are not 
supported."
         transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 
3])
@@ -1019,6 +1040,62 @@ class TorchFXImporter:
 
         return self.block_builder.emit(attn)
 
+||||||| parent of 97e234777 (Add einsum, gelu, pad support to relax)
+=======
+    def _pad_common(self, mode, pad_value, inputs):
+        data = self.env[inputs[0]]
+        pad_list = inputs[1]
+
+        # initialize paddings based on input len
+        pad_len = len(data.struct_info.shape) * 2
+        paddings = [0] * pad_len
+
+        if len(pad_list) >= 2:
+            paddings[-1] = pad_list[1]
+            paddings[-2] = pad_list[0]
+        if len(pad_list) >= 4:
+            paddings[-3] = pad_list[3]
+            paddings[-4] = pad_list[2]
+        if len(pad_list) >= 6:
+            paddings[-5] = pad_list[5]
+            paddings[-6] = pad_list[4]
+
+        # group into tuple of 2 ints
+        paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]
+
+        if mode == "constant":
+            return self.block_builder.emit(relax.op.nn.pad(data, paddings, 
pad_value=relax.const(pad_value), pad_mode=mode))
+        else:
+            return self.block_builder.emit(relax.op.nn.pad(data, paddings, 
pad_mode=mode))
+
+    def _pad(self, node: fx.node.Node) -> relax.Expr:
+        if len(node.args) > 2 and node.args[2] is not None:
+            mode = node.args[2]
+        else:
+            mode = 'constant'
+
+        if len(node.args) == 4 and node.args[3] is not None:
+            pad_value = node.args[3]
+        else:
+            pad_value = 0
+        return self._pad_common(mode, pad_value, node.args)
+
+    def _unbind(self, node: fx.node.Node) -> relax.Expr:
+        data = self.env[node.args[0]]
+        shape = data.struct_info.shape
+        if len(node.args) == 2 and node.args[1] is not None:
+            axis = node.args[1]
+        else:
+            axis = 0
+
+        selections = shape[axis]
+        res_split = relax.op.split(data, selections, axis)
+        ret = []
+        for i in range(selections.value):
+            ret.append(relax.op.squeeze(res_split[i], axis=[axis]))
+        return self.block_builder.emit(relax.expr.Tuple(ret))
+
+
     ########## Others ##########
 
     def _size(self, node: fx.node.Node) -> relax.Expr:
@@ -1040,6 +1117,8 @@ class TorchFXImporter:
         return getattr(self.env[node.args[0]], node.args[1])
 
     def _getitem(self, node: fx.node.Node) -> relax.Var:
+        from torch import fx
+
         x = self.env[node.args[0]]
         if isinstance(x, (list, tuple, relax.ShapeExpr, relax.Tuple)):
             return x[node.args[1]]
@@ -1095,10 +1174,15 @@ class TorchFXImporter:
                 sliced_shape.insert(i, 1)
             return self.block_builder.emit(relax.op.reshape(sliced, 
sliced_shape))
         elif isinstance(x, relax.Constant):
-            dtype = x.struct_info.dtype
-            return relax.const(x.data.numpy()[node.args[1]], dtype)
+            if isinstance(node.args[1], fx.node.Node):
+                idx = self.env[node.args[1]]
+                return self.block_builder.emit(relax.op.take(x,idx,axis=0))
+            else:
+                dtype = x.struct_info.dtype
+                return relax.const(x.data.numpy()[node.args[1]], dtype)
         else:
-            assert False
+            import IPython;IPython.embed()
+            raise ValueError(f"Unsupported type {type(x)} for _getitem, should 
be list, tuple, ShapeExpr, Tuple, Var, or Constant")
 
     def create_convert_map(self):
         from torch import nn
@@ -1118,6 +1202,7 @@ class TorchFXImporter:
                 relax.op.clip(self.env[node.args[0]], 0, 6)
             ),
             nn.SiLU: lambda node: 
self.block_builder.emit(relax.op.nn.silu(self.env[node.args[0]])),
+            nn.GELU: lambda node: 
self.block_builder.emit(relax.op.nn.gelu(self.env[node.args[0]])),
             nn.Flatten: self._flatten,
             nn.BatchNorm2d: self._batch_norm_2d,
             nn.LayerNorm: self._layer_norm,
@@ -1154,6 +1239,13 @@ class TorchFXImporter:
             "sum": self._sum,
             "float": self._float,
             "half": self._half,
+            "long": self._long,
+            "int64": self._long,
+            "int": self._int,
+            "int32": self._int,
+            "short": self._short,
+            "int16": self._short,
+            "int8": self._int8,
             "type": self._type,
             "astype": self._type,
             "matmul": self._matmul,
@@ -1203,6 +1295,9 @@ class TorchFXImporter:
             "max": self._max,
             "cross_entropy": self._cross_entropy,
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
+            "pad": self._pad,
+            "unbind": self._unbind,
+            "einsum": self._einsum,
         }
 
     def from_fx(
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index fb5e0736ff..5c58cc5966 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """Relax Neural Network (NN) operators"""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Sequence
 
 from tvm import DataType
 from tvm.tir import FloatImm
@@ -1022,3 +1022,29 @@ def attention(
         (batch_size, seq_len, num_head, head_dim_v).
     """
     return _ffi_api.attention(query, key, value, bias, scale)  # type: ignore
+
+def pad(data: Expr, pad_width: Sequence[Sequence[int]], pad_value: Union[Expr, 
float]=0., pad_mode: str="constant"):
+    r"""Padding
+
+    This operator takes in a tensor and pads each axis by the specified
+    widths using the specified value.
+
+    Parameters
+    ----------
+    data: relax.Expr
+        The input data to the operator
+    pad_width: tuple of <tuple of <int>>, required
+        Number of values padded to the edges of each axis, in the format
+        of ((before_1, after_1), ..., (before_N, after_N))
+    pad_value: float, or relax.Expr, optional, default=0
+        The value used for padding
+    pad_mode: 'constant', 'edge', 'reflect'
+        'constant' pads with constant_value pad_value
+        'edge' pads using the edge values of the input array
+        'reflect' pads by reflecting values with respect to the edge
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 9c98682e32..b362660561 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -397,3 +397,9 @@ def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr:
         reduction=call.attrs.reduction,
         ignore_index=call.attrs.ignore_index,
     )
+
+@register_legalize("relax.nn.pad")
+def _pad(bb: BlockBuilder, call: Call) -> Expr:
+    assert call.attrs.pad_mode == "constant"
+    # return bb.call_te(topi.nn.pad, call.args[0], [x[0] for x in 
call.attrs.pad_width], [x[1] for x in call.attrs.pad_width], call.args[1])
+    return bb.call_te(topi.nn.pad, call.args[0], [x[0] for x in 
call.attrs.pad_width], [x[1] for x in call.attrs.pad_width], 
tir.const(call.args[1].data.numpy()[()]))
diff --git a/src/relax/op/nn/pad.cc b/src/relax/op/nn/pad.cc
new file mode 100644
index 0000000000..224ff7cb51
--- /dev/null
+++ b/src/relax/op/nn/pad.cc
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file pad.cc
+ * \brief Implementation of operator pad
+ */
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/op.h>
+#include <tvm/topi/elemwise.h>
+#include <tvm/topi/nn.h>
+
+#include <vector>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+// relax.nn.pad
+TVM_REGISTER_NODE_TYPE(PadAttrs);
+
+InferLayoutOutput PadInferLayout(const Call& call, 
+                                         const Map<String, Array<String>>& 
desired_layouts,
+                                         const VarLayoutMap& var_layout_map) {
+
+  LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  return InferLayoutOutput({exisiting_layout},
+                           {exisiting_layout},
+                           Attrs(call->attrs));
+  // TODO: handle layout changes for pad
+  // const auto* attrs_ptr = call->attrs.as<PadAttrs>();
+  // CHECK(attrs_ptr);
+  // ObjectPtr<PadAttrs> params = make_object<PadAttrs>(*attrs_ptr);
+  //
+  // Layout ret_data;
+  // // If new_in_layouts are defined, this code tries to modify the layout.
+  // bool is_layout_modified = new_in_layouts.defined();
+  // if (new_in_layouts.defined()) {
+  //   // Create a map of axis to param_width. For the new layout, a new 
param_width is generated using
+  //   // the map. The new layout is rejected, if the padding is happening 
along the axis which was
+  //   // split.
+  //
+  //   // 1) Create a map from axis to param_width using old layout.
+  //   std::map<std::string, tvm::Array<Integer>> axis_pad_width;
+  //   int index_counter = 0;
+  //   ICHECK_EQ(new_in_layouts.size(), 2);
+  //   ICHECK_EQ(old_in_layouts.size(), 2);
+  //   for (auto iter_var : old_in_layouts[0]->axes) {
+  //     const auto& old_layout_axis = LayoutAxis::Get(iter_var);
+  //     axis_pad_width.emplace(old_layout_axis.name(), 
params->pad_width[index_counter]);
+  //     index_counter++;
+  //   }
+  //
+  //   // 2) Create new pad width by walking over the new layout and using the 
map.
+  //   tvm::Array<tvm::Array<Integer>> new_pad_width;
+  //   for (auto iter_var : new_in_layouts[0]->axes) {
+  //     const auto& new_layout_axis = LayoutAxis::Get(iter_var);
+  //     auto axis_name = new_layout_axis.name();
+  //     if (axis_pad_width.count(axis_name) != 0 && 
new_layout_axis.IsPrimal()) {
+  //       // This is primal axis. So, directly use the original pad_width.
+  //       new_pad_width.push_back(axis_pad_width.at(axis_name));
+  //     } else {
+  //       // This is the axis that got split. So, check that pad_width was 
[0, 0] originally.
+  //       const auto& dual_axis = new_layout_axis.ToPrimal();
+  //       auto dual_axis_name = dual_axis.name();
+  //       ICHECK(axis_pad_width.count(dual_axis_name))
+  //           << "Missing axis " << dual_axis << " in " << 
old_in_layouts[0].name();
+  //       new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
+  //
+  //       // If any pad_width element is not zero, do not change the layout.
+  //       for (auto width : axis_pad_width.at(dual_axis_name)) {
+  //         if (auto* width_imm = width.as<IntImmNode>()) {
+  //           if (width_imm->value != 0) {
+  //             is_layout_modified = false;
+  //           }
+  //         } else {
+  //           is_layout_modified = false;
+  //         }
+  //       }
+  //     }
+  //   }
+  //
+  //   // If the above conditions satisfied, we can set the newly created 
pad_width and use the new
+  //   // layout.
+  //   if (is_layout_modified) {
+  //     ret_data = new_in_layouts[0];
+  //     params->pad_width = new_pad_width;
+  //   }
+  // }
+  //
+  // if (!is_layout_modified) {
+  //   if (old_in_layouts.defined()) {
+  //     ICHECK_EQ(old_in_layouts.size(), 2);
+  //     ret_data = old_in_layouts[0];
+  //   } else {
+  //     ret_data = Layout::Undef();
+  //   }
+  // }
+  //
+  // // The pad value is always a scalar
+  // Layout ret_pad_value = Layout("1");
+  // return InferLayoutOutput({ret_data, ret_pad_value}, {ret_data}, 
Attrs(params));
+}
+
+StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
+   auto infos = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = infos[0];
+   const auto* attrs = call->attrs.as<PadAttrs>();
+
+   const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+   ICHECK_EQ(data_shape->values.size(), attrs->pad_width.size()) << "Data 
shape and padding should be the same size";
+
+   Array<PrimExpr> padded_shape;
+   for(size_t i = 0; i < data_shape->values.size(); i++) {
+     padded_shape.push_back(data_shape->values[i] + attrs->pad_width[i][0] + 
attrs->pad_width[i][1]);
+   }
+   return TensorStructInfo(ShapeExpr(padded_shape), data_sinfo->dtype);
+}
+
+Array<te::Tensor> PadCompute(const Attrs& attrs, const Array<te::Tensor>& 
inputs,
+                             const Type& out_type) {
+  const auto* param = attrs.as<PadAttrs>();
+  ICHECK(param != nullptr);
+
+  auto pad_width = param->pad_width;
+  ICHECK(pad_width.size() == inputs[0].ndim() && pad_width[0].size() == 2) << 
"Illegal pad_width";
+  Array<PrimExpr> pad_before;
+  for (size_t i = 0; i < pad_width.size(); ++i) {
+    pad_before.push_back(pad_width[i][0]);
+  }
+  Array<PrimExpr> pad_after;
+  for (size_t i = 0; i < pad_width.size(); ++i) {
+    pad_after.push_back(pad_width[i][1]);
+  }
+  te::Tensor cast_pad_value = topi::cast(inputs[1], inputs[0]->dtype);
+  const PrimExpr& pad_value = 
cast_pad_value(Array<PrimExpr>(inputs[1]->shape.size(), 0));
+  return Array<te::Tensor>{topi::pad(inputs[0], pad_before, pad_after, 
pad_value, "T_pad",
+                                     topi::kElementWise, param->pad_mode)};
+}
+
+// Handler to create a call to the padding op used by front-end FFI
+Expr MakePad(Expr data, Array<Array<Integer>> pad_width, Expr pad_value, 
String pad_mode) {
+  auto attrs = make_object<PadAttrs>();
+  attrs->pad_width = std::move(pad_width);
+  attrs->pad_mode = std::move(pad_mode);
+  static const Op& op = Op::Get("relax.nn.pad");
+  return Call(op, {data, pad_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(MakePad);
+
+TVM_REGISTER_OP("relax.nn.pad")
+    .describe(R"code(Pad for n-D tensor.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<PadAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("pad_val", "Tensor", "The value to fill the padded area 
with")
+    .set_support_level(2)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", PadInferLayout);
+    // .set_attr<TOpPattern>("TOpPattern", kInjective)
+    // .set_attr<FTVMCompute>("FTVMCompute", PadCompute);
+
+// relax.nn.mirror_pad
+TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);
+
+bool MirrorPadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                  const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>();
+  ICHECK(param != nullptr);
+
+  // check that pad widths match lengths
+  ICHECK(data->shape.size() == param->pad_width.size())
+      << "There should be as many pad width pairs as shape dimensions "
+      << "but the shape has " << data->shape.size() << " dimensions "
+      << "and there are " << param->pad_width.size() << " pad width pairs.";
+
+  // each pad width element should be a pair of positive integers
+  std::vector<PrimExpr> oshape;
+  for (size_t i = 0; i < param->pad_width.size(); i++) {
+    ICHECK(param->pad_width[i].size() == 2)
+        << "Each pad width element should be a pair but at index " << i << " 
there are "
+        << param->pad_width[i].size() << " elements.";
+
+    auto width1 = tir::as_const_int(param->pad_width[i][0]);
+    auto width2 = tir::as_const_int(param->pad_width[i][1]);
+    ICHECK(width1 != nullptr);
+    ICHECK(width2 != nullptr);
+
+    ICHECK(*width1 >= 0) << "Param width elements should be positive but first 
pad width at "
+                         << "index " << i << " is " << *width1 << ".";
+    ICHECK(*width2 >= 0) << "Param width elements should be positive but first 
pad width at "
+                         << "index " << i << " is " << *width2 << ".";
+
+    auto padding = tir::make_const(data->shape[i].dtype(), *width1 + *width2);
+    oshape.push_back(data->shape[i] + padding);
+  }
+
+  reporter->Assign(types[1], TensorType(Array<PrimExpr>(oshape), data->dtype));
+  return true;
+}
+
+// Handler to create a call to the padding op used by front-end FFI
+Expr MakeMirrorPad(Expr data, Array<Array<PrimExpr>> pad_width, String mode) {
+  auto attrs = make_object<MirrorPadAttrs>();
+  attrs->mode = mode;
+  attrs->pad_width = std::move(pad_width);
+  static const Op& op = Op::Get("nn.mirror_pad");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn._make.mirror_pad").set_body_typed(MakeMirrorPad);
+
+TVM_REGISTER_OP("nn.mirror_pad")
+    .describe(R"code(MirrorPad for n-D tensor.
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<MirrorPadAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_support_level(2);
+    // .add_type_rel("MirrorPad", MirrorPadRel)
+    // .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 7428fd66fc..fa550f415d 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -29,7 +29,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call& 
call, const BlockBu
   int n_input = op->arguments.size();
   if (static_cast<int>(call->args.size()) != n_input) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << op << " op should have " << n_input << " arguments");
+                     << op << " op should have " << n_input << " arguments" << 
" but it got " << call->args.size());
   }
   Array<TensorStructInfo> input_tensor_sinfo;
   input_tensor_sinfo.reserve(n_input);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 4439a9c3d7..4b3a0f5a7a 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -937,7 +937,7 @@ TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs 
args, TVMRetValue* ret) {
   } else if (args[0].type_code() == kDLFloat) {
     *ret = tir::make_const(args[1], args[0].operator double(), args[2]);
   } else {
-    LOG(FATAL) << "only accept int or float";  // FIXME
+    LOG(FATAL) << "Constant must be int or float but got " << 
tvm::runtime::ArgTypeCode2Str(args[0].type_code());  // FIXME
   }
 });
 

Reply via email to