ThomasDelteil commented on issue #8833: For gluon, how to define NET (MODEL) 
and LOSS for multi_labels? 
URL: 
https://github.com/apache/incubator-mxnet/issues/8833#issuecomment-409370392
 
 
   @dbsxdbsx 
   
   I slightly modified your code, adding `expand_dims` calls and changing the
   `concat` dim so that your output is indeed `(batch, label_num, num_class)`:
   
   ```python
   class Net(gluon.HybridBlock):
       def __init__(self, **kwargs):
           super(Net, self).__init__(**kwargs)
   
           with self.name_scope():
               self.cov1 = gluon.nn.Conv2D(channels=32, kernel_size=(5, 5))
               self.cov2 = gluon.nn.Conv2D(channels=32, kernel_size=(5, 5))
               self.cov3 = gluon.nn.Conv2D(channels=32, kernel_size=(3, 3))
   
                self.max_pool = gluon.nn.MaxPool2D(pool_size=(2, 2), strides=(1, 1))
                self.avg_pool = gluon.nn.AvgPool2D(pool_size=(2, 2), strides=(1, 1))
   
               self.flatten = gluon.nn.Flatten()
               self.dense_256 = gluon.nn.Dense(256)
               self.dense_10 = gluon.nn.Dense(10)
   
       def hybrid_forward(self, F, x):
            x = F.relu(self.max_pool(self.cov1(x)))
            x = F.relu(self.max_pool(self.cov2(x)))
            for _ in range(2):
                x = F.relu(self.avg_pool(self.cov3(x)))

            x = self.flatten(x)
            x = F.relu(self.dense_256(x))
            x1 = self.dense_10(x).expand_dims(axis=1)  # each head: (batch, 1, num_class)
            x2 = self.dense_10(x).expand_dims(axis=1)
            x3 = self.dense_10(x).expand_dims(axis=1)
            x4 = self.dense_10(x).expand_dims(axis=1)
            out = F.concat(x1, x2, x3, x4, dim=1)  # (batch, num_label, num_class)
            return out
   ```
   For example, with `batch=8`, `num_label=4`, `num_class=10` we indeed get
   the shape we expect:
   ```python
   net = Net()
   net.initialize()
   output = net(mx.nd.ones((8,3,224,224)))
   output.shape
   ```
   ```
   (8, 4, 10)
   ```
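   One thing to watch out for: since the single `self.dense_10` layer is applied four times to the same `x`, the four heads share the same weights and always produce identical logits, so the network as written cannot predict different classes for different label positions. A minimal numpy sketch of the effect (the fix would be four separate `Dense` layers, one per label position):
   ```python
   import numpy as np

   rng = np.random.default_rng(0)
   x = rng.standard_normal((8, 256))         # flattened features (batch, 256)
   W = rng.standard_normal((256, 10))        # one shared "Dense" layer
   b = rng.standard_normal(10)

   # Reusing one layer applies the same affine map four times...
   out = np.stack([x @ W + b for _ in range(4)], axis=1)
   print(out.shape)                          # (8, 4, 10)

   # ...so every label position carries identical logits
   print(np.allclose(out[:, 0], out[:, 3]))  # True
   ```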
   
   For the loss, as @piiswrong indicated, `SoftmaxCrossEntropyLoss` does indeed
   work here. For example, consider the following code:
   
   ```python
   loss = gluon.loss.SoftmaxCrossEntropyLoss()
   
   # We create a hypothetical output of shape (batch, num_label, num_class)
   batch = 2
   num_label = 4
   num_class = 10
    output = mx.nd.zeros((batch, num_label, num_class))
   
   # We create a corresponding label of shape (batch, num_label)
    label = [1, 2, 3, 4]
   labels = mx.nd.array([label]*batch)
   
   labels
   ```
   ```
   [[ 1.  2.  3.  4.]
    [ 1.  2.  3.  4.]]
   <NDArray 2x4 @cpu(0)>
   ```
   
   ```python
    # We assign a high value to the output at the correct label index
    for i in range(batch):
        for j in range(num_label):
            output[i][j][int(labels[i][j].asscalar())] = 100
   
   
   output
   ```
   ```
   [[[   0.  100.    0.    0.    0.    0.    0.    0.    0.    0.]
     [   0.    0.  100.    0.    0.    0.    0.    0.    0.    0.]
     [   0.    0.    0.  100.    0.    0.    0.    0.    0.    0.]
     [   0.    0.    0.    0.  100.    0.    0.    0.    0.    0.]]
   
    [[   0.  100.    0.    0.    0.    0.    0.    0.    0.    0.]
     [   0.    0.  100.    0.    0.    0.    0.    0.    0.    0.]
     [   0.    0.    0.  100.    0.    0.    0.    0.    0.    0.]
     [   0.    0.    0.    0.  100.    0.    0.    0.    0.    0.]]]
   <NDArray 2x4x10 @cpu(0)>
   ```
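   As an aside, the double loop can be vectorized. A numpy sketch of the same idea (mxnet offers a `one_hot` operator that can serve the same purpose):
   ```python
   import numpy as np

   batch, num_label, num_class = 2, 4, 10
   labels = np.tile([1, 2, 3, 4], (batch, 1))        # shape (2, 4)

   # Index an identity matrix with the labels to one-hot encode them, then
   # scale: output[i, j, labels[i, j]] == 100, everything else is 0
   output = np.eye(num_class)[labels] * 100
   print(output.shape)                               # (2, 4, 10)
   ```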
   
   ```python
   loss(output, labels)
   ```
   Indeed the loss is 0 because the output matches the labels!
   ```
   [ 0.  0.]
   <NDArray 2 @cpu(0)>
   ```
   If we modify one value to be as high as the one corresponding to the
   correct label, we should see the loss increase for the first sample of our batch:
   ```python
   output[0][0][0] = 100
   
   loss(output, labels)
   ```
   ```
   [ 0.1732868  0.       ]
   <NDArray 2 @cpu(0)>
   ```
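   The value `0.1732868` can be sanity-checked by hand, assuming (as the printed result suggests) that the loss averages over the label axis. After the edit, position 0 of sample 0 has two logits of 100 (classes 0 and 1), so the softmax assigns probability 0.5 to the correct class:
   ```python
   import math

   # Two tied logits of 100 -> softmax gives the correct class probability 0.5,
   # so the cross-entropy at that position is -ln(0.5) = ln 2.
   # The other three label positions still contribute ~0.
   sample_loss = (math.log(2) + 0.0 + 0.0 + 0.0) / 4  # mean over num_label=4
   print(round(sample_loss, 7))  # 0.1732868
   ```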

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
