BebDong opened a new issue #20067:
URL: https://github.com/apache/incubator-mxnet/issues/20067


   ## Description
   - A cross-entropy loss with online hard example mining (OHEM), which is hard to implement for multi-GPU training with the high-level Python API.
   - Adding it would also keep MXNet consistent with GluonCV: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/loss.py#L456
   - The following code does not work for multi-GPU training.
   
   ```python
   from gluoncv import loss as gloss
   
   
   class OHEMCrossEntropyLoss(gloss.SoftmaxCrossEntropyLoss):
       """
       OHEM cross-entropy loss.
       Only supports a single GPU.
       Adapted from:
           https://github.com/PaddlePaddle/PaddleSeg/blob/release/v2.0/paddleseg/models/losses/ohem_cross_entropy_loss.py
       """
   
       def __init__(self, thresh=0.7, min_kept=10000, num_classes=21, height=None, width=None,
                    crop_size=480, sparse_label=True, batch_axis=0, ignore_label=-1,
                    size_average=True, **kwargs):
           super(OHEMCrossEntropyLoss, self).__init__(sparse_label, batch_axis, ignore_label,
                                                      size_average, **kwargs)
           self._thresh = thresh
           self._min_kept = min_kept
           self._nclass = num_classes
           self._height = height if height is not None else crop_size
           self._width = width if width is not None else crop_size
   
       def hybrid_forward(self, F, logit, label):
           # Flatten labels to (N,) and zero out ignored pixels.
           label = F.reshape(label, shape=(-1,))
           valid_mask = (label != self._ignore_label)
           num_valid = F.sum(valid_mask)
           label = label * valid_mask
   
           # Softmax over classes, laid out as (num_classes, N).
           prob = F.softmax(logit, axis=1)
           prob = F.reshape(F.transpose(prob, axes=(1, 0, 2, 3)), shape=(self._nclass, -1))
   
           # NOTE: the Python `if`s and `asnumpy()` below need concrete NDArray
           # values, so this block only runs imperatively (not hybridized).
           if self._min_kept < num_valid and num_valid > 0:
               # Push ignored pixels' probability above 1 so they are never mined.
               prob = prob + (1 - valid_mask)
               # Probability of the ground-truth class for each pixel, shape (N,).
               prob = F.pick(prob, label, axis=0, keepdims=False)
   
               threshold = self._thresh
               if self._min_kept > 0:
                   # If the min_kept-th smallest probability exceeds thresh, use it
                   # as the threshold so that roughly min_kept pixels survive.
                   index = F.argsort(prob)
                   threshold_index = index[min(len(index), self._min_kept) - 1]
                   threshold_index = int(threshold_index.asnumpy()[0])
                   if prob[threshold_index] > self._thresh:
                       threshold = prob[threshold_index]
                   # Keep only the hard pixels (low ground-truth probability).
                   kept_mask = (prob < threshold)
                   label = label * kept_mask
                   valid_mask = valid_mask * kept_mask
   
           # Mark every invalid (ignored or easy) pixel with ignore_label.
           label = label + (1 - valid_mask) * self._ignore_label
           label = F.reshape(label, shape=(-1, self._height, self._width))
           return super(OHEMCrossEntropyLoss, self).hybrid_forward(F, logit, label)
   ```
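The single-GPU limitation comes from the data-dependent control flow in the code above: the Python `if`s on `num_valid` and `prob[threshold_index]`, plus `asnumpy()`, all need concrete values on the host. A minimal NumPy sketch, assuming the same `(num_classes, N)` layout, shows the same pixel selection written without branching on data: the k-th smallest probability comes from a sort at a static index, the inner `if` becomes `maximum`, and the outer `if` becomes a `where` select. The function name `ohem_keep_mask` is hypothetical and is not a GluonCV or MXNet API; porting it to `F` operators (for example `F.sort` or `F.topk`, `F.maximum`, `F.where`) is an untested assumption.

```python
import numpy as np


def ohem_keep_mask(prob, label, thresh=0.7, min_kept=10000, ignore_label=-1):
    """Hypothetical branch-free sketch of the OHEM pixel selection (NumPy only).

    prob:  (C, N) softmax probabilities, classes first, N = number of pixels.
    label: (N,) integer labels; ignore_label marks pixels to skip.
    Returns a boolean (N,) mask of the pixels that contribute to the loss.
    """
    label = np.asarray(label)
    n = label.size
    valid = label != ignore_label
    num_valid = valid.sum()
    if min_kept <= 0:
        # Static config check, not data-dependent: no mining, keep all valid pixels.
        return valid
    # Ground-truth class probability per pixel (like F.pick); ignored pixels
    # are shifted above 1 so they can never fall below the threshold.
    safe_label = np.where(valid, label, 0)
    gt_prob = prob[safe_label, np.arange(n)] + (~valid)
    # k-th smallest probability; k depends only on static sizes, not on values.
    k = min(n, min_kept) - 1
    kth = np.sort(gt_prob)[k]
    # "if kth > thresh: threshold = kth" becomes an element-wise maximum.
    threshold = np.maximum(thresh, kth)
    mined = valid & (gt_prob < threshold)
    # "if min_kept < num_valid and num_valid > 0" becomes a select.
    use_ohem = (min_kept < num_valid) & (num_valid > 0)
    return np.where(use_ohem, mined, valid)
```

Like the original, the comparison is strict (`<`), so when the k-th smallest probability becomes the threshold, that pixel itself is excluded.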
   
   ## References
   - A. Shrivastava, A. Gupta, and R. Girshick. Training region-based object detectors with online hard example mining. In CVPR, 2016.
   