This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f5f1b91 use nd for accuracy calculation (#9583)
f5f1b91 is described below
commit f5f1b91ff972ad70e9131d3cd1d7408ddddb7684
Author: Sheng Zha <[email protected]>
AuthorDate: Fri Jan 26 22:06:07 2018 -0800
use nd for accuracy calculation (#9583)
* use nd for accuracy calculation
* check for context
---
python/mxnet/metric.py | 16 ++++++++++------
1 file changed, 10 insertions(+), 6 deletions(-)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 5b0780a..f1cdae2 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -380,23 +380,27 @@ class Accuracy(EvalMetric):
Parameters
----------
labels : list of `NDArray`
- The labels of the data.
+ The labels of the data with class indices as values, one per sample.
preds : list of `NDArray`
- Predicted values.
+ Prediction values for samples. Each prediction value can either be the class index,
+ or a vector of likelihoods for all classes.
"""
check_label_shapes(labels, preds)
for label, pred_label in zip(labels, preds):
if pred_label.shape != label.shape:
pred_label = ndarray.argmax(pred_label, axis=self.axis)
- pred_label = pred_label.asnumpy().astype('int32')
- label = label.asnumpy().astype('int32')
+ pred_label = pred_label.astype('int32')
+ label = label.astype('int32')
check_label_shapes(label, pred_label)
- self.sum_metric += (pred_label.flat == label.flat).sum()
- self.num_inst += len(pred_label.flat)
+ if pred_label.context != label.context:
+ pred_label = pred_label.as_in_context(label.context)
+
+ self.sum_metric += (pred_label.flatten() ==
label.flatten()).sum().asscalar()
+ self.num_inst += numpy.prod(pred_label.shape)
@register
--
To stop receiving notification emails like this one, please contact
[email protected].