echuraev commented on code in PR #14757:
URL: https://github.com/apache/tvm/pull/14757#discussion_r1184563983
##########
python/tvm/relax/training/loss.py:
##########
@@ -290,3 +290,94 @@ def __call__(
bb.emit_func_output(loss)
return bb.get()[self._loss_name]
+
+class CategoricalCrossEntropyLoss(Loss):
+ r"""CategoricalCrossEntropyLoss.
+ It is a combination of a converting one-hot target vector to a label,
+ a log_softmax computation and a nll_loss.
+
+ Parameters
+ ----------
+ reduction : Literal["mean", "sum", "none"]
+ The reduction method to apply to output. Can be "mean", "sum" or
"none".
+
+ none : no reduction will be applied,
+ mean : the sum of the output will be divided by the batch_size,
+ sum : the output will be summed.
+
+ ignore_index : int
+ Specifies a target value that is ignored and does not contribute to
the input gradient.
+ """
+
+ ignore_index: int
+
+ def __init__(
+ self,
+ reduction: Literal["mean", "sum", "none"] = "mean",
+ ignore_index: int = -100,
+ ) -> None:
+ super().__init__("categorical_cross_entropy_loss", 1, reduction)
+ self.ignore_index = ignore_index
+
+ def __call__(
+ self,
+ predictions: Union[Var, StructInfo],
+ targets: Union[Var, StructInfo],
+ weights: Optional[Union[Var, StructInfo]] = None,
+ ) -> Function:
+ """Get the relax function of CategoricalCrossEntropyLoss. If the
parameters are
+ struct info, it will create corresponding variables.
+
+ Parameters
+ ----------
+ predictions : Union[Var, StructInfo]
+ The predictions of the model in the calculation of loss.
+
+ targets : Union[Var, StructInfo]
+ The ground truth in the calculation of loss.
+
+ weights : Optional[Union[Var, StructInfo]]
+ a manual rescaling weight given to each class. It has to be a
Tensor of size C.
+
+ Returns
+ -------
+ The relax function of CategoricalCrossEntropyLoss with the loss name
as its global symbol.
+ """
+
+ bb = BlockBuilder()
+
+ predictions = _create_param_var(predictions, "predictions")
+ targets = _create_param_var(targets, "targets")
+
+ arg_list = [predictions, targets]
+ if weights:
+ weights = _create_param_var(weights, "weights")
+ arg_list.append(weights)
+
+ if self.ignore_index >= 0:
Review Comment:
Could you please add a comment explaining why there are two
implementations depending on the value of `ignore_index`? I think it would be
useful for future readers.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]