# [GitHub] sxjscience commented on issue #9586: mx.metric F1 is using numpy logic

```sxjscience commented on issue #9586: mx.metric F1 is using numpy logic
URL:
https://github.com/apache/incubator-mxnet/issues/9586#issuecomment-365427676

Bringing the discussion to the correct place. I implemented an ndarray version
of the F1 score while running experiments, and I've included my `nd_f1`
below:

This may be useful if we want to accelerate the F1 score computation in the
future. Also, we can take advantage of the fact that the [micro F1 is
equivalent to accuracy for single-label
classification](https://stackoverflow.com/questions/37358496/is-f1-micro-the-same-as-accuracy)
to accelerate the computation.

```python
import mxnet.ndarray as nd
from sklearn.metrics import f1_score
import numpy as np
import mxnet as mx
import time

def nd_f1(pred, label, num_class, average="micro"):
    """Evaluate the F1 score using mx.nd.NDArray operations.

    Parameters
    ----------
    pred : nd.NDArray
        Predicted class indices. Shape (num, label_num) or (num,)
    label : nd.NDArray
        Ground-truth class indices. Shape (num, label_num) or (num,)
    num_class : int
        Number of classes; must be greater than 1.
    average : str
        Either "micro" or "macro".

    Returns
    -------
    f1 : float
        The F1 score. Precision, recall or F1 terms whose denominator is
        zero are scored as 0 instead of producing NaN.

    Raises
    ------
    AssertionError
        If `num_class` is not greater than 1 or `pred` and `label` differ
        in number of dimensions.
    NotImplementedError
        If `average` is neither "micro" nor "macro".
    """
    # Cast each input on its own; label must not depend on pred's dtype.
    if pred.dtype != np.float32:
        pred = pred.astype(np.float32)
    if label.dtype != np.float32:
        label = label.astype(np.float32)
    assert num_class > 1, "num_class must be greater than 1"
    assert pred.ndim == label.ndim, "pred and label must have the same ndim"
    # Reject unsupported modes before doing any array work.
    if average not in ("micro", "macro"):
        raise NotImplementedError(
            "average=%r is not supported; use 'micro' or 'macro'" % (average,))

    if num_class == 2 and average == "micro":
        # Binary shortcut: class 1 is treated as the positive class and the
        # counts are accumulated over every element of pred/label.
        tp = nd.sum((pred == 1) * (label == 1)).asscalar()
        fp = nd.sum((pred == 1) * (label == 0)).asscalar()
        fn = nd.sum((pred == 0) * (label == 1)).asscalar()
        return _f1_from_counts(tp, fp, fn)

    # One-hot encoding appends a trailing class axis of size num_class.
    pred_onehot = nd.one_hot(indices=pred, depth=num_class)
    label_onehot = nd.one_hot(indices=label, depth=num_class)
    tp = pred_onehot * label_onehot
    fp = pred_onehot * (1 - label_onehot)
    fn = (1 - pred_onehot) * label_onehot

    if average == "micro":
        return _f1_from_counts(nd.sum(tp).asscalar(),
                               nd.sum(fp).asscalar(),
                               nd.sum(fn).asscalar())

    # Macro: reduce every axis except the trailing class axis so that one
    # count per class remains.
    reduce_axes = (0, 1) if tp.ndim == 3 else 0
    tp = nd.sum(tp, axis=reduce_axes)
    fp = nd.sum(fp, axis=reduce_axes)
    fn = nd.sum(fn, axis=reduce_axes)
    # Counts are non-negative whole numbers, so clamping the denominator to
    # at least 1 only affects classes whose denominator is 0 (where tp is
    # also 0); those classes contribute 0 rather than NaN to the mean.
    precision = float(nd.mean(tp / nd.maximum(tp + fp, 1)).asscalar())
    recall = float(nd.mean(tp / nd.maximum(tp + fn, 1)).asscalar())
    # NOTE: this is the harmonic mean of macro-averaged precision and recall.
    # It is not the mean of per-class F1 scores (the definition used by
    # scikit-learn), so small differences from sklearn are expected.
    return _harmonic_f1(precision, recall)


def _f1_from_counts(tp, fp, fn):
    """Return F1 computed from scalar TP / FP / FN counts."""
    tp, fp, fn = float(tp), float(fp), float(fn)
    # An empty denominator means nothing was predicted (or nothing was
    # labelled) positive; score that term as 0 instead of dividing by zero.
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    return _harmonic_f1(precision, recall)


def _harmonic_f1(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 if both are 0)."""
    total = precision + recall
    if total == 0:
        return 0.0
    return 2 * (precision * recall) / total

# Benchmark nd_f1 against sklearn's f1_score on two random inputs:
#   1. single-label data with 50 classes, shape (100000,)
#   2. binary data with shape (10000, 121)
for pred_npy, label_npy, num_class in [
        (np.random.randint(0, 50, size=(100000,)),
         np.random.randint(0, 50, size=(100000,)),
         50),
        (np.random.randint(0, 2, size=(10000, 121)),
         np.random.randint(0, 2, size=(10000, 121)),
         2)]:
    # Test F1 score
    for average in ['micro', 'macro']:
        # Time five runs of the sklearn (numpy) reference implementation.
        start = time.time()
        for _ in range(5):
            f1_npy = f1_score(y_true=label_npy, y_pred=pred_npy,
                              average=average)
        end = time.time()
        print("Average=", average, "Npy Time Spent:", end - start)

        # Copy the inputs to the GPU and wait for the copies to finish so
        # the transfer is not included in the timed region.
        pred_nd = nd.array(pred_npy, ctx=mx.gpu(), dtype=np.float32)
        label_nd = nd.array(label_npy, ctx=mx.gpu(), dtype=np.float32)
        nd.waitall()
        # One untimed call before the measured loop (presumably a warm-up).
        f1_nd = nd_f1(pred=pred_nd,
                      label=label_nd,
                      num_class=num_class,
                      average=average)
        nd.waitall()

        # Time five runs of the NDArray implementation.
        start = time.time()
        for _ in range(5):
            f1_nd = nd_f1(pred=pred_nd,
                          label=label_nd,
                          num_class=num_class,
                          average=average)
        nd.waitall()
        end = time.time()
        # The 'abs diff:' literal was split across two lines by mail
        # wrapping (a syntax error); it is rejoined here.
        print("Average=", average, "NDArray Time Spent:", end - start,
              'abs diff:', abs(f1_nd - f1_npy))
```

Result:
```
Average= micro Npy Time Spent: 0.1795516014099121
Average= micro NDArray Time Spent: 0.033802032470703125 abs diff: 0.0
Average= macro Npy Time Spent: 0.17911505699157715
Average= macro NDArray Time Spent: 0.07393026351928711 abs diff:
4.64383991273e-06
Average= micro Npy Time Spent: 0.6379575729370117
Average= micro NDArray Time Spent: 0.029665708541870117 abs diff: 0.0
Average= macro Npy Time Spent: 0.6377367973327637
Average= macro NDArray Time Spent: 0.034937143325805664 abs diff:
0.000381544355229
``````
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.