This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f5f1b91  use nd for accuracy calculation (#9583)
f5f1b91 is described below

commit f5f1b91ff972ad70e9131d3cd1d7408ddddb7684
Author: Sheng Zha <[email protected]>
AuthorDate: Fri Jan 26 22:06:07 2018 -0800

    use nd for accuracy calculation (#9583)
    
    * use nd for accuracy calculation
    
    * check for context
---
 python/mxnet/metric.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 5b0780a..f1cdae2 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -380,23 +380,27 @@ class Accuracy(EvalMetric):
         Parameters
         ----------
         labels : list of `NDArray`
-            The labels of the data.
+            The labels of the data with class indices as values, one per sample.
 
         preds : list of `NDArray`
-            Predicted values.
+            Prediction values for samples. Each prediction value can either be the class index,
+            or a vector of likelihoods for all classes.
         """
         check_label_shapes(labels, preds)
 
         for label, pred_label in zip(labels, preds):
             if pred_label.shape != label.shape:
                 pred_label = ndarray.argmax(pred_label, axis=self.axis)
-            pred_label = pred_label.asnumpy().astype('int32')
-            label = label.asnumpy().astype('int32')
+            pred_label = pred_label.astype('int32')
+            label = label.astype('int32')
 
             check_label_shapes(label, pred_label)
 
-            self.sum_metric += (pred_label.flat == label.flat).sum()
-            self.num_inst += len(pred_label.flat)
+            if pred_label.context != label.context:
+                pred_label = pred_label.as_in_context(label.context)
+
+            self.sum_metric += (pred_label.flatten() == 
label.flatten()).sum().asscalar()
+            self.num_inst += numpy.prod(pred_label.shape)
 
 
 @register

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to