This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6732a9e3b2 [Relay] Implement `SoftmaxRel` for softmax operators. (#11728)
6732a9e3b2 is described below
commit 6732a9e3b2d64316926693e91d5ca6a54fc75958
Author: WANG Zihan <[email protected]>
AuthorDate: Thu Jun 16 16:45:56 2022 +0800
[Relay] Implement `SoftmaxRel` for softmax operators. (#11728)
* Implement `SoftmaxRel` for softmax operators.
* Print better error message for wrong axis.
---
src/relay/op/nn/nn.cc | 27 ++++++++++++++++++++++++---
1 file changed, 24 insertions(+), 3 deletions(-)
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 41b47401de..b8d48d9e9e 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
// relay.softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
+// Type relation for the softmax family of operators (nn.softmax,
+// nn.fast_softmax, nn.log_softmax).
+//
+// types[0] is the input tensor type, types[1] the output. Softmax preserves
+// shape and dtype, so on success the input type is assigned to the output
+// unchanged. Before doing so, validates that the `axis` attribute lies in
+// [-ndim, ndim) for the input's rank, emitting a diagnostic otherwise —
+// this is the improvement over the generic IdentityRel, which performed no
+// axis checking.
+bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  // Input type not yet resolved to a tensor: defer until more is known.
+  if (data == nullptr) return false;
+
+  const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
+  ICHECK(param != nullptr);
+  int axis = param->axis;
+  int ndim = static_cast<int>(data->shape.size());
+  // Negative axes index from the end, so the valid range is [-ndim, ndim).
+  if (axis >= ndim || axis < -ndim) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Wrong axis (" << axis << ") not in expected range: ["
+                                     << -ndim << ", " << ndim << ")");
+    return false;
+  }
+
+  reporter->Assign(types[1], types[0]);
+  return true;
+}
+
TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
@@ -420,7 +441,7 @@ RELAY_REGISTER_OP("nn.softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
- .add_type_rel("Identity", IdentityRel);
+ .add_type_rel("Softmax", SoftmaxRel);
// relay.fast_softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
@@ -447,7 +468,7 @@ RELAY_REGISTER_OP("nn.fast_softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
- .add_type_rel("Identity", IdentityRel);
+ .add_type_rel("Softmax", SoftmaxRel);
// relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
@@ -471,7 +492,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
- .add_type_rel("Identity", IdentityRel)
+ .add_type_rel("Softmax", SoftmaxRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<SoftmaxAttrs>();