zburning removed a comment on issue #16880: Better to flatten the label array in metric.F1()
URL: https://github.com/apache/incubator-mxnet/issues/16880#issuecomment-557964023
 
 
   @sxjscience Sorry, I didn't make it clear. An example script is:
   
   ```
   import mxnet as mx

   pred = mx.nd.array([[-2.4738965,   2.7095912],
                       [ 1.4827207,  -1.6053244],
                       [ 0.66689086, -1.0119148],
                       [ 0.54501575, -0.8739182],
                       [ 1.7229283,  -1.80466  ],
                       [-2.1540372,   2.3391898],
                       [-0.574123,    0.18217295],
                       [-1.5451021,   1.3035003],
                       [-2.366786,    2.5836499],
                       [-2.469643,    2.6291811]])

   label = mx.nd.array([[1], [0], [0], [1], [0], [1], [0], [1], [1], [1]])

   print(pred.shape, label.shape)  # pred shape: (10, 2), label shape: (10, 1)

   metric = mx.metric.F1()
   metric.update([label], [pred])
   print(metric.get())  # ('f1', 0.6)

   metric.reset()
   metric.update([label.reshape(-1)], [pred])  # label shape: (10,)
   print(metric.get())  # ('f1', 0.8333333333333334) <- this is the correct result
   ```
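
   For reference, here is a plain-numpy sketch of why the first call reports 0.6: with a (batch, 1) label, the boolean masks broadcast against the (batch,) predicted labels, so every prediction gets compared with every label. (This assumes the stats are accumulated from elementwise products of the masks, which is roughly what happens internally; see the snippet from update_binary_stats() below.)
   ```
   import numpy as np

   pred_label = np.array([1, 0, 0, 0, 0, 1, 1, 1, 1, 1])                    # argmax of pred, shape (10,)
   label_2d = np.array([[1], [0], [0], [1], [0], [1], [0], [1], [1], [1]])  # shape (10, 1)

   # (10,) vs (10, 1) broadcasts to a (10, 10) matrix, so the counts are wrong:
   tp = ((pred_label == 1) * (label_2d == 1)).sum()   # 36 instead of 5
   fp = ((pred_label == 1) * (label_2d == 0)).sum()   # 24 instead of 1
   fn = ((pred_label == 0) * (label_2d == 1)).sum()   # 24 instead of 1
   precision, recall = tp / (tp + fp), tp / (tp + fn)
   print(2 * precision * recall / (precision + recall))   # 0.6

   # With the label flattened to (10,), the counts are per-sample and correct:
   label_1d = label_2d.ravel()
   tp = ((pred_label == 1) * (label_1d == 1)).sum()   # 5
   fp = ((pred_label == 1) * (label_1d == 0)).sum()   # 1
   fn = ((pred_label == 0) * (label_1d == 1)).sum()   # 1
   precision, recall = tp / (tp + fp), tp / (tp + fn)
   print(2 * precision * recall / (precision + recall))   # 0.8333...
   ```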
   The current F1() uses the _BinaryClassificationMetrics class to update its stats, and _BinaryClassificationMetrics.update_binary_stats() starts with:
   ```
   pred = pred.asnumpy()
   label = label.asnumpy().astype('int32')
   pred_label = numpy.argmax(pred, axis=1)
   check_label_shapes(label, pred)
   ```
   The problem is that numpy.argmax(pred, axis=1) returns an array of shape (batch,), and the subsequent computation requires the label to have the same shape, i.e. the label should also be (batch,). Also, the check_label_shapes() call effectively does nothing here, because its keyword argument "shape" defaults to False. So the function runs without raising an error but returns a wrong result. It is easy to fix, but since you mentioned refactoring to support multi-label classification, maybe we won't rely on _BinaryClassificationMetrics() in the future? In any case, I think the current behavior of _BinaryClassificationMetrics() is not good, and other metrics (e.g. MCC()) can potentially suffer from the same problem.
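
   A possible minimal fix (just a sketch of the idea, not a tested patch) would be to flatten the label at the top of update_binary_stats() and then do a strict shape check against pred_label:
   ```
   # sketch: inside _BinaryClassificationMetrics.update_binary_stats()
   pred = pred.asnumpy()
   label = label.asnumpy().astype('int32').reshape(-1)   # flatten (batch, 1) -> (batch,)
   pred_label = numpy.argmax(pred, axis=1)
   # assuming shape=True makes this an exact shape comparison, as described above
   check_label_shapes(label, pred_label, shape=True)
   ```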
   
   
   
   
