altanh commented on a change in pull request #8056:
URL: https://github.com/apache/tvm/pull/8056#discussion_r638226321



##########
File path: src/relay/op/nn/nn.cc
##########
@@ -1091,6 +1092,65 @@ Accept logits.
 // Depth to space and space to depth
 TVM_REGISTER_NODE_TYPE(SubPixelAttrs);
 
+// relay.nn.nll_loss
+TVM_REGISTER_NODE_TYPE(NLLLossAttrs);
+
+bool NLLLossRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                const TypeReporter& reporter) {
+  ICHECK_EQ(types.size(), 4) << "NLLLossRel expects 4 types, but " << types.size()
+                             << " were provided.";
+  const auto* predictions = types[0].as<TensorTypeNode>();
+  const auto* targets = types[1].as<TensorTypeNode>();
+  const auto* weights = types[2].as<TensorTypeNode>();
+  const NLLLossAttrs* param = attrs.as<NLLLossAttrs>();
+  if (predictions == nullptr || targets == nullptr || weights == nullptr) return false;
+  ICHECK(predictions->shape.size() - targets->shape.size() == 1)
+      << "NLLLossRel: predictions should be one dimension larger than targets, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "targets shape = " << targets->shape;
+  ICHECK(weights->shape.size() == 1)
+      << "NLLLossRel: weights should be a one dimension Tensor with its length "
+      << "the number of classes, but Tensor of dimension " << weights->shape.size()
+      << " were provided.";
+  ICHECK(reporter->AssertEQ(predictions->shape[1], weights->shape[0]))
+      << "NLLLossRel: the second dimension of predictions should be the number of classes, "
+      << "which is the length of weights, "
+      << "predictions shape = " << predictions->shape << ", "
+      << "weights shape = " << weights->shape;
+  ICHECK(predictions->dtype == weights->dtype && predictions->dtype.is_float())
+      << "NLLLossRel: predictions and weights should be of the same floating type.";
+  ICHECK(targets->dtype.is_int()) << "NLLLossRel: targets should be of int type.";

Review comment:
       Basically, if the error can be triggered by user input (e.g. passing tensors with the wrong shapes), we should definitely use diagnostics. ICHECK should be reserved for internal compiler checks that should never fail unless there's a bug somewhere. The diagnostic framework is fairly new, so a lot of older code still uses ICHECK incorrectly; unfortunately, we just need to slowly go through and update it.
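To make the distinction concrete, here is a minimal standalone C++ sketch, not TVM code: `DiagnosticContext`, `EmitError`, `INTERNAL_CHECK`, and `CheckNLLLossShapes` are hypothetical names made up for illustration. It re-expresses only the shape checks from the diff above so that wrong user shapes are recorded as errors instead of aborting, while a condition that can only fail through a compiler bug (a null context) stays a hard check. TVM's real diagnostic API, reached through the `TypeReporter` in a type relation, looks different.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for a diagnostic context (NOT TVM's API). It collects
// user-facing errors instead of aborting, so every problem in the user's
// program can be reported together.
struct DiagnosticContext {
  std::vector<std::string> errors;
  void EmitError(const std::string& msg) { errors.push_back(msg); }
};

// Hypothetical ICHECK analogue, reserved for internal invariants: if it
// fires, the compiler itself has a bug, so aborting is acceptable.
#define INTERNAL_CHECK(cond, msg)                                  \
  do {                                                             \
    if (!(cond)) {                                                 \
      std::cerr << "internal compiler error: " << (msg) << '\n';   \
      std::abort();                                                \
    }                                                              \
  } while (0)

using Shape = std::vector<int64_t>;

std::string ShapeToString(const Shape& shape) {
  std::ostringstream os;
  os << "[";
  for (size_t i = 0; i < shape.size(); ++i) {
    os << (i ? ", " : "") << shape[i];
  }
  os << "]";
  return os.str();
}

// Mirrors the shape rules from the NLLLossRel diff, but reports errors caused
// by user input as diagnostics. Returns true if all shapes are valid.
bool CheckNLLLossShapes(const Shape& predictions, const Shape& targets,
                        const Shape& weights, DiagnosticContext* diag) {
  // A null context can only come from a bug in the calling compiler code,
  // never from user input, so this stays an internal check.
  INTERNAL_CHECK(diag != nullptr, "diagnostic context must not be null");
  bool ok = true;
  if (predictions.size() != targets.size() + 1) {
    diag->EmitError("predictions should be one dimension larger than targets, "
                    "predictions shape = " + ShapeToString(predictions) +
                    ", targets shape = " + ShapeToString(targets));
    ok = false;
  }
  if (weights.size() != 1) {
    diag->EmitError("weights should be a 1-D tensor whose length is the number "
                    "of classes, but weights shape = " + ShapeToString(weights));
    ok = false;
  }
  // The diff reads predictions->shape[1], which assumes at least two
  // dimensions; only compare class counts when both indices are in range.
  if (predictions.size() >= 2 && weights.size() == 1 &&
      predictions[1] != weights[0]) {
    diag->EmitError("the second dimension of predictions should equal the "
                    "length of weights, predictions shape = " +
                    ShapeToString(predictions) + ", weights shape = " +
                    ShapeToString(weights));
    ok = false;
  }
  return ok;
}
```

Collecting errors rather than aborting means a call with predictions `[4, 3]`, targets `[4, 1]`, and weights `[5]` reports both the rank mismatch and the class-count mismatch in one pass, which the ICHECK version cannot do.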




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
