zihaolucky commented on a change in pull request #9195: [WIP]NCE loss gluon
URL: https://github.com/apache/incubator-mxnet/pull/9195#discussion_r161791018
 
 

 ##########
 File path: python/mxnet/gluon/loss.py
 ##########
 @@ -696,3 +697,85 @@ def hybrid_forward(self, F, pred, positive, negative):
                      axis=self._batch_axis, exclude=True)
         loss = F.relu(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
+
+
+class NoiseContrastiveEstimationLoss(Loss):
+    r"""Calculates the noise contrastive estimation loss:
+
+    The central idea of NCE is to perform a nonlinear logistic regression to
+    discriminate between the observed data and some artificially generated
+    noise data. For each true target, the loss is computed from one positive
+    sample and `num_sampled` negative samples drawn from a noise distribution:
+
+    .. math::
+
+        {pred} = {class\_embedding} \cdot {activation}
+
+        {prob} = \frac{1}{1 + \exp(-{pred})}
+
+        L = - \sum_i \left[{label}_i \log({prob}_i) +
+            (1 - {label}_i) \log(1 - {prob}_i)\right]
+
+    where `pred` is the scalar result of the inner product between a
+    `class_embedding` vector and an `activation` vector of the same dimension.
+    For the positive class `label` is 1; for the sampled negative classes
+    `label` is 0.
+
+    Parameters
+    ----------
+    num_sampled : int
+        Number of sampled noise targets for NCE calculation.
+    num_classes : int
+        Number of classes.
+    noise_distribution : list of float or NDArray
+        Probability distribution over the classes, used to sample the noise targets.
+
+
+    Inputs:
+        - **weight**: the class embeddings. Type: Embedding, shape (num_classes, dim).
+        - **inputs**: forward activation tensor of the network, shape (batch_size, dim).
+        - **targets**: ground-truth target tensor, shape (batch_size,).
+
+    Outputs:
+        - **loss**: loss tensor with shape (batch_size,). Dimensions other than
+          batch_axis are averaged out.
+
+    References
+    ----------
+        `Noise-contrastive estimation: A new estimation principle for
+        unnormalized statistical models
+        <proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf>`_
+
+    """
+
+    def __init__(self, num_sampled, num_classes, noise_distribution, **kwargs):
+        super(NoiseContrastiveEstimationLoss, self).__init__(num_sampled, num_classes, **kwargs)
+        self.sigmoid_binary_ce_loss = SigmoidBinaryCrossEntropyLoss(**kwargs)
+        self.alias_method_sampler = AliasMethodSampler(num_classes, noise_distribution)
+        self._num_sampled = num_sampled
+
+    def hybrid_forward(self, F, weights, inputs, targets, **kwargs):
+        preds, labels = self._compute_sampled_values(F, weights, inputs, targets)
+        return self.sigmoid_binary_ce_loss(preds, labels, **kwargs)
+
+    def _compute_sampled_values(self, F, weights, inputs, targets):
+        """Sample negative targets and compute activations"""
+        # (batch_size, dim)
+        targets_embedding = weights(targets)
+        targets_pred = F.broadcast_mul(targets_embedding, inputs)
+        targets_pred = F.sum(data=targets_pred, axis=1)
+
+        # Sample the negative labels.
+        batch_size = inputs.shape[0]
+        sampled_negatives = self.alias_method_sampler.draw(batch_size * self._num_sampled)
+
+        # shape: (batch_size * num_sampled, embed_size)
+        negatives_embedding = weights(sampled_negatives)
+        negatives_embedding = F.reshape(negatives_embedding, (batch_size, self._num_sampled, -1))
+        _inputs = F.reshape(inputs, (batch_size, 1, -1))
+        # shape: (batch_size, num_sampled, embed_size)
+        negatives_pred = F.broadcast_mul(_inputs, negatives_embedding)
+        negatives_pred = F.sum(negatives_pred, axis=2)
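
 For context, a minimal usage sketch of the proposed loss as written in this diff. The class and `AliasMethodSampler` come from this PR and are not part of a released MXNet, so the names, signatures, and the uniform noise distribution below are assumptions that may change:

```python
from mxnet import gluon, nd

num_classes, dim, batch_size, num_sampled = 1000, 16, 4, 5

# Class embeddings are passed in as the `weight` input (shape: num_classes x dim).
class_embedding = gluon.nn.Embedding(num_classes, dim)
class_embedding.initialize()

# A simple uniform noise distribution over the classes (assumption; any unigram counts would do).
noise_distribution = [1.0 / num_classes] * num_classes

loss_fn = NoiseContrastiveEstimationLoss(num_sampled=num_sampled,
                                         num_classes=num_classes,
                                         noise_distribution=noise_distribution)

inputs = nd.random.normal(shape=(batch_size, dim))                       # network activations
targets = nd.random.uniform(0, num_classes, shape=(batch_size,)).floor() # class indices

loss = loss_fn(class_embedding, inputs, targets)                         # shape: (batch_size,)
```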
 
 Review comment:
   I failed to figure out a way to prettify the `axis` and `dim`, although we have `batch_axis`. @szha
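
   One possible sketch (not part of this PR; shapes assumed to match the diff above) that hides the explicit `axis`/`dim` handling is to express both inner products with `F.batch_dot`, so the reduction axis never appears:

```python
# targets_embedding: (batch_size, dim), inputs: (batch_size, dim)
targets_pred = F.batch_dot(F.expand_dims(targets_embedding, axis=1),  # (batch_size, 1, dim)
                           F.expand_dims(inputs, axis=2))             # (batch_size, dim, 1)
targets_pred = F.reshape(targets_pred, (-1,))                         # (batch_size,)

# negatives_embedding: (batch_size, num_sampled, dim)
negatives_pred = F.batch_dot(negatives_embedding,
                             F.expand_dims(inputs, axis=2))           # (batch_size, num_sampled, 1)
negatives_pred = F.reshape(negatives_pred, (0, -1))                   # (batch_size, num_sampled)
```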

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
