This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c3616ed Add label_from_zero_one argument to LogisticLoss (#9265)
c3616ed is described below
commit c3616edc8059e556d2e38970d4e862088f0e035d
Author: Xingjian Shi <[email protected]>
AuthorDate: Thu Jan 4 11:22:14 2018 -0800
Add label_from_zero_one argument to LogisticLoss (#9265)
* add use_zero_one argument to logisticloss
* add comment
* revise name
* update
* update
---
python/mxnet/gluon/loss.py | 17 +++++++++++++----
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 614025c..435230e 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -619,8 +619,8 @@ class LogisticLoss(Loss):
L = \sum_i \log(1 + \exp(- {pred}_i \cdot {label}_i))
where `pred` is the classifier prediction and `label` is the target tensor
- containing values -1 or 1. `pred` and `label` can have arbitrary shape as
- long as they have the same number of elements.
+ containing values -1 or 1 (0 or 1 if `label_format` is binary).
+ `pred` and `label` can have arbitrary shape as long as they have the same number of elements.
Parameters
----------
@@ -628,7 +628,10 @@ class LogisticLoss(Loss):
Global scalar weight for loss.
batch_axis : int, default 0
The axis that represents mini-batch.
-
+ label_format : str, default 'signed'
+ Can be either 'signed' or 'binary'. If the label_format is 'signed', all label values should
+ be either -1 or 1. If the label_format is 'binary', all label values should be either
+ 0 or 1.
Inputs:
- **pred**: prediction tensor with arbitrary shape.
@@ -643,11 +646,17 @@ class LogisticLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimenions other than
batch_axis are averaged out.
"""
- def __init__(self, weight=None, batch_axis=0, **kwargs):
+ def __init__(self, weight=None, batch_axis=0, label_format='signed',
**kwargs):
super(LogisticLoss, self).__init__(weight, batch_axis, **kwargs)
+ self._label_format = label_format
+ if self._label_format not in ["signed", "binary"]:
+ raise ValueError("label_format can only be signed or binary,
recieved %s."
+ % label_format)
def hybrid_forward(self, F, pred, label, sample_weight=None):
label = _reshape_like(F, label, pred)
+ if self._label_format == 'binary':
+ label = 2 * label - 1 # Transform label to be either -1 or 1
loss = F.log(1.0 + F.exp(-pred * label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].