This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6732a9e3b2 [Relay] Implement `SoftmaxRel` for softmax operators. (#11728)
6732a9e3b2 is described below

commit 6732a9e3b2d64316926693e91d5ca6a54fc75958
Author: WANG Zihan <[email protected]>
AuthorDate: Thu Jun 16 16:45:56 2022 +0800

    [Relay] Implement `SoftmaxRel` for softmax operators. (#11728)
    
    * Implement `SoftmaxRel` for softmax operators.
    
    * Print better error message for wrong axis.
---
 src/relay/op/nn/nn.cc | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 41b47401de..b8d48d9e9e 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -399,6 +399,27 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
 // relay.softmax
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
 
+bool SoftmaxRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const SoftmaxAttrs* param = attrs.as<SoftmaxAttrs>();
+  ICHECK(param != nullptr);
+  int axis = param->axis;
+  int ndim = static_cast<int>(data->shape.size());
+  if (axis >= ndim || axis < -ndim) {
+    reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
+                                     << "Wrong axis (" << axis << ") not in expected range: ["
+                                     << -ndim << ", " << ndim << ")");
+    return false;
+  }
+
+  reporter->Assign(types[1], types[0]);
+  return true;
+}
+
 TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax").set_body_typed([](Expr data, int axis) {
   auto attrs = make_object<SoftmaxAttrs>();
   attrs->axis = axis;
@@ -420,7 +441,7 @@ RELAY_REGISTER_OP("nn.softmax")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(1)
-    .add_type_rel("Identity", IdentityRel);
+    .add_type_rel("Softmax", SoftmaxRel);
 
 // relay.fast_softmax
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
@@ -447,7 +468,7 @@ RELAY_REGISTER_OP("nn.fast_softmax")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(1)
-    .add_type_rel("Identity", IdentityRel);
+    .add_type_rel("Softmax", SoftmaxRel);
 
 // relay.nn.log_softmax
 TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax").set_body_typed([](Expr data, int axis) {
@@ -471,7 +492,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
     .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
     .set_support_level(1)
-    .add_type_rel("Identity", IdentityRel)
+    .add_type_rel("Softmax", SoftmaxRel)
     .set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                                              const Type& out_type) {
       const auto* param = attrs.as<SoftmaxAttrs>();

Reply via email to