echuraev commented on code in PR #14757:
URL: https://github.com/apache/tvm/pull/14757#discussion_r1184728155
##########
python/tvm/relax/training/loss.py:
##########
@@ -290,3 +290,94 @@ def __call__(
bb.emit_func_output(loss)
return bb.get()[self._loss_name]
+
+class CategoricalCrossEntropyLoss(Loss):
+ r"""CategoricalCrossEntropyLoss.
+ It is a combination of a converting one-hot target vector to a label,
+ a log_softmax computation and a nll_loss.
+
+ Parameters
+ ----------
+ reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+ none : no reduction will be applied,
+ mean : the sum of the output will be divided by the batch_size,
+ sum : the output will be summed.
+
+ ignore_index : int
+        Specifies a target value that is ignored and does not contribute to the input gradient.
+ """
+
+ ignore_index: int
+
+ def __init__(
+ self,
+ reduction: Literal["mean", "sum", "none"] = "mean",
+ ignore_index: int = -100,
+ ) -> None:
+ super().__init__("categorical_cross_entropy_loss", 1, reduction)
+ self.ignore_index = ignore_index
+
+ def __call__(
+ self,
+ predictions: Union[Var, StructInfo],
+ targets: Union[Var, StructInfo],
+ weights: Optional[Union[Var, StructInfo]] = None,
+ ) -> Function:
+        """Get the relax function of CategoricalCrossEntropyLoss. If the parameters are
+ struct info, it will create corresponding variables.
+
+ Parameters
+ ----------
+ predictions : Union[Var, StructInfo]
+ The predictions of the model in the calculation of loss.
+
+ targets : Union[Var, StructInfo]
Review Comment:
Should we check the type of the targets? As far as I remember, only the `int64`
data type is applicable for `targets`.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]