yongwww commented on code in PR #14517:
URL: https://github.com/apache/tvm/pull/14517#discussion_r1159758803


##########
src/relax/op/nn/nn.cc:
##########
@@ -492,5 +492,230 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits")
     .add_argument("labels", "Tensor", "The labels.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCrossEntropy);
 
+/* relax.nn.nll_loss */
+TVM_REGISTER_NODE_TYPE(NLLLossAttrs);
+
+Expr nll_loss(Expr predictions, Expr targets, Optional<Expr> weights, String 
reduction,
+              int ignore_index) {
+  ObjectPtr<NLLLossAttrs> attrs = make_object<NLLLossAttrs>();
+
+  ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean")
+      << "The argument reduction of NLLLoss should be one of the following "
+         "values: none, mean, sum. However, the given value is "
+      << reduction;
+
+  attrs->reduction = std::move(reduction);
+  attrs->ignore_index = ignore_index;
+
+  static const Op& op = Op::Get("relax.nn.nll_loss");
+  if (weights.defined()) {
+    return Call(op, {std::move(predictions), std::move(targets), 
std::move(weights.value())},
+                Attrs{attrs}, {});
+  } else {
+    return Call(op, {std::move(predictions), std::move(targets)}, 
Attrs{attrs}, {});
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss);
+
+StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() < 2 || call->args.size() > 3) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 
arguments");
+  }
+
+  const auto* pred_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* tgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const TensorStructInfoNode* wgt_sinfo = nullptr;
+  if (call->args.size() == 3) {
+    wgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+    if (wgt_sinfo == nullptr) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << "NLLLoss requires the argument weights to be Tensor. However, the 
given one is "
+          << call->args[1]->struct_info_->GetTypeKey());

Review Comment:
   This error path fires when `wgt_sinfo` (the third argument, `weights`) is not a Tensor, so the diagnostic should print `call->args[2]->struct_info_`, not `call->args[1]`.



##########
python/tvm/relax/op/nn/nn.py:
##########
@@ -914,6 +914,52 @@ def cross_entropy_with_logits(predictions: Expr, labels: 
Expr) -> Expr:
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: 
ignore
 
 
+def nll_loss(
+    predictions: Expr,
+    targets: Expr,
+    weights: Optional[Expr] = None,
+    reduction: str = "mean",
+    ignore_index: int = -100,
+) -> Expr:
+    """Negative log likelihood loss.
+
+    `output[n, i_1, i_2, ..., i_k] = -p * w`, where
+    - `p = predictions[n, t, i_1, i_2, i_k]`,
+    - `t = targets[n, i_1, i_2, ..., i_k]`,
+    - `w = weights[t] if t != ignore_index else 0`
+
+    result = reduction(output)
+
+    Parameters
+    ----------
+    predictions : relax.Expr
+      The predictions. Should be a `(k+2)-D` Tensor with shape `(N, C, d_1, 
d_2, ..., d_k)` where C
+      is the number of target classes.
+
+    targets : relax.Expr
+      The target value of each prediction. Should be a `(k+1)-D` Tensor with 
shape
+      `(N, d_1, d_2, ..., d_k)`. Must be of int dtype.
+
+    weights : Optional[relax.Expr]
+      The weight of each target value. Should be a `1-D` Tensor with shape 
`(C,)`.
+      If not specified, it is treated as if having all ones.
+
+    reduction : str
+      The reduction method to apply to the output.
+      Possible values are "mean", "sum" and "none".
+
+    ignore_index : int
+      The target value to ignore.
+
+      The computed result.

Review Comment:
   The line "The computed result." is currently placed under the `ignore_index` parameter description; it should be moved into a dedicated "Returns" section of the docstring.



##########
src/relax/op/nn/nn.cc:
##########
@@ -492,5 +492,230 @@ TVM_REGISTER_OP("relax.nn.cross_entropy_with_logits")
     .add_argument("labels", "Tensor", "The labels.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoCrossEntropy);
 
+/* relax.nn.nll_loss */
+TVM_REGISTER_NODE_TYPE(NLLLossAttrs);
+
+Expr nll_loss(Expr predictions, Expr targets, Optional<Expr> weights, String 
reduction,
+              int ignore_index) {
+  ObjectPtr<NLLLossAttrs> attrs = make_object<NLLLossAttrs>();
+
+  ICHECK(reduction == "none" || reduction == "sum" || reduction == "mean")
+      << "The argument reduction of NLLLoss should be one of the following "
+         "values: none, mean, sum. However, the given value is "
+      << reduction;
+
+  attrs->reduction = std::move(reduction);
+  attrs->ignore_index = ignore_index;
+
+  static const Op& op = Op::Get("relax.nn.nll_loss");
+  if (weights.defined()) {
+    return Call(op, {std::move(predictions), std::move(targets), 
std::move(weights.value())},
+                Attrs{attrs}, {});
+  } else {
+    return Call(op, {std::move(predictions), std::move(targets)}, 
Attrs{attrs}, {});
+  }
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.nll_loss").set_body_typed(nll_loss);
+
+StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() < 2 || call->args.size() > 3) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "NLLLoss op should take 2 or 3 
arguments");
+  }
+
+  const auto* pred_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* tgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  const TensorStructInfoNode* wgt_sinfo = nullptr;
+  if (call->args.size() == 3) {
+    wgt_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+    if (wgt_sinfo == nullptr) {
+      ctx->ReportFatal(
+          Diagnostic::Error(call)
+          << "NLLLoss requires the argument weights to be Tensor. However, the 
given one is "
+          << call->args[1]->struct_info_->GetTypeKey());
+    }
+  }
+
+  if (pred_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "NLLLoss requires the argument preditions to be Tensor. However, 
the given one is "
+        << call->args[0]->struct_info_->GetTypeKey());
+  }
+  if (tgt_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "NLLLoss requires the argument targets to be Tensor. However, the 
given one is "
+        << call->args[2]->struct_info_->GetTypeKey());

Review Comment:
   This error path fires when `tgt_sinfo` (the second argument, `targets`) is not a Tensor, so the diagnostic should print `call->args[1]->struct_info_`, not `call->args[2]`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to