wkcn commented on issue #18643:
URL: https://github.com/apache/incubator-mxnet/issues/18643#issuecomment-651837537
Hi @DongfeiJi,
It works for me on MXNet 2.0.
Note that `boolean_mask` doesn't work when the mask is all zeros/`False`,
since the legacy operators don't support zero-size arrays.
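As a minimal sketch of that caveat (the exact error text depends on your MXNet build; `mx.MXNetError` is MXNet's base exception):
```python
import mxnet as mx

x = mx.nd.array([1.0, 2.0, 3.0, 4.0])
mask = mx.nd.array([0, 1, 0, 1])
print(mx.nd.contrib.boolean_mask(x, mask))  # [2. 4.]

all_false = mx.nd.zeros(4)
try:
    # an all-false mask selects nothing, so the result would be a zero-size array
    out = mx.nd.contrib.boolean_mask(x, all_false)
    out.wait_to_read()  # force the async op to execute so the failure surfaces
except mx.MXNetError as e:
    print('all-false mask not supported:', e)
```
Here is the full working example: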
```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.loss import Loss, _apply_weighting
class NewTripletLoss(Loss):
    def __init__(self, batch_size_per_gpu, margin=1, weight=None,
                 batch_axis=0, **kwargs):
        super(NewTripletLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_size_per_gpu = batch_size_per_gpu
        self.margin = margin

    def hybrid_forward(self, F, embeddings, labels, sample_weight=None):
        N = self.batch_size_per_gpu
        # pairwise squared distances: d(i, j) = ||x_i||^2 + ||x_j||^2 - 2 x_i.x_j
        xx = F.power(embeddings, 2).sum(1, keepdims=True).tile((1, N))
        dist = F.broadcast_add(xx, xx.transpose())
        dist = F.broadcast_sub(dist, 2 * F.dot(embeddings, embeddings.transpose()))
        dist = F.clip(dist, 1e-12, 1e12)
        # masks: is_pos marks pairs with the same label, is_neg different labels
        labels = F.cast(labels, dtype='float32')
        labels = labels.expand_dims(1).tile((1, N))
        is_pos = F.broadcast_equal(labels, labels.transpose())
        is_neg = F.broadcast_not_equal(labels, labels.transpose())
        # hard example mining: hardest positive (max) / hardest negative (min)
        dist_mat = dist.reshape((N * N,))
        pos_mask = is_pos.reshape((N * N,))
        dist_ap = F.contrib.boolean_mask(dist_mat, pos_mask).reshape((N, -1))
        # dist_ap = F.broadcast_mul(dist_mat, pos_mask).reshape((N, -1))
        dist_ap = F.max(dist_ap, axis=1)
        neg_mask = is_neg.reshape((N * N,))
        dist_an = F.contrib.boolean_mask(dist_mat, neg_mask).reshape((N, -1))
        # dist_an = F.broadcast_mul(dist_mat, neg_mask).reshape((N, -1))
        dist_an = F.min(dist_an, axis=1)
        # hinge: loss = max(d_ap - d_an + margin, 0)
        margin = F.full(shape=(N, 1), val=self.margin)
        loss = F.broadcast_add(F.broadcast_sub(dist_ap, dist_an), margin)
        loss = F.maximum(loss, F.zeros_like(loss))
        # apply sample weighting
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


block = NewTripletLoss(2)
block.hybridize()
embeddings = mx.nd.array([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]).reshape((2, 3))
embeddings.attach_grad()
labels = mx.nd.array([0, 1]).reshape((2,))
with mx.autograd.record():
    out = block(embeddings, labels)
out.sum().backward()
print(out)
mx.nd.waitall()
```
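By the way, the commented-out `F.broadcast_mul` lines hint at a mask-free fallback that avoids `boolean_mask` entirely. A rough sketch of that idea (the `masked_hard_mining` helper below is hypothetical, not part of the loss above): zeroing non-positive entries is safe for the row-wise max because the clipped distances are non-negative, but for the row-wise min the masked-out entries must be lifted to a large value rather than zeroed:
```python
import mxnet as mx

def masked_hard_mining(dist, is_pos, is_neg):
    # hypothetical helper: hard mining without boolean_mask, so an
    # all-false row cannot produce a zero-size array
    # hardest positive: zeroed-out entries cannot win a max over
    # non-negative distances
    dist_ap = mx.nd.max(dist * is_pos, axis=1)
    # hardest negative: lift masked-out entries to a huge value so the
    # min ignores them (zeroing them would wrongly return 0)
    dist_an = mx.nd.min(dist + (1 - is_neg) * 1e12, axis=1)
    return dist_ap, dist_an

dist = mx.nd.array([[0.0, 2.0], [2.0, 0.0]])
is_pos = mx.nd.array([[1.0, 0.0], [0.0, 1.0]])
print(masked_hard_mining(dist, is_pos, 1 - is_pos))  # ([0. 0.], [2. 2.])
```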