piiswrong commented on a change in pull request #7304: gluon bce & ctc losses
URL: https://github.com/apache/incubator-mxnet/pull/7304#discussion_r131712363
 
 

 ##########
 File path: python/mxnet/gluon/loss.py
 ##########
 @@ -142,6 +144,46 @@ def hybrid_forward(self, F, output, label, 
sample_weight=None):
         return F.mean(loss, axis=self._batch_axis, exclude=True)
 
 
+class BinaryCrossEntropyLoss(Loss):
+    r"""The cross-entropy loss for binary classification.
+
+    BCE loss is useful when training logistic regression.
+
+    .. math::
+        loss(o, t) = - 1/n \sum_i (t[i] * log(o[i]) + (1 - t[i]) * log(1 - 
o[i]))
+
+
+    Parameters
+    ----------
+    from_sigmoid : bool, default is `False`
+        Whether the input is from the output of sigmoid. Set this to false 
will make
+        the loss calculate sigmoid and then BCE, which is more numerically 
stable through
+        log-sum-exp trick.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    """
+    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, 
**kwargs):
+        super(BinaryCrossEntropyLoss, self).__init__(weight, batch_axis, 
**kwargs)
+        self._from_sigmoid = from_sigmoid
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        label = label.reshape((-1, 1))
 
 Review comment:
   See `L1Loss` for how to reshape `label`.
 
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to