develop and test precision for multi-label metric
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/dde8d14b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/dde8d14b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/dde8d14b

Branch: refs/heads/master
Commit: dde8d14b751a4fe65e373be5076c19ef8f178818
Parents: f6cf8f5
Author: RUAN0007 <[email protected]>
Authored: Wed Feb 22 21:12:42 2017 +0800
Committer: RUAN0007 <[email protected]>
Committed: Wed Feb 22 21:50:44 2017 +0800

----------------------------------------------------------------------
 python/singa/metric.py     | 125 +++++++++++++++++++++++++++++++++++++++-
 test/python/test_metric.py |  56 ++++++++++++++++++
 2 files changed, 180 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dde8d14b/python/singa/metric.py
----------------------------------------------------------------------
diff --git a/python/singa/metric.py b/python/singa/metric.py
index da8213b..2492965 100644
--- a/python/singa/metric.py
+++ b/python/singa/metric.py
@@ -38,6 +38,7 @@ Example usage::
 
 from . import singa_wrap as singa
 import tensor
+import numpy as np
 
 
 class Metric(object):
@@ -78,9 +79,131 @@ class Metric(object):
 
 class Accuracy(Metric):
-    '''Compute the top one accuracy for singel label prediction tasks.
+    '''Compute the top one accuracy for single label prediction tasks.
 
     It calls the C++ functions to do the calculation.
     '''
 
     def __init__(self):
         self.swig_metric = singa.Accuracy()
+
+
+class Precision(Metric):
+    '''Make the top-k labels of max probability as the prediction
+
+    Compute the precision against the groundtruth labels
+    '''
+    def __init__(self, top_k):
+        self.top_k = top_k
+
+
+
+    def forward(self, x, y):
+        '''Compute the precision for each sample.
+
+        Convert tensor to numpy for computation
+
+        Args:
+            x (Tensor): predictions, one row per sample
+            y (Tensor): ground truth labels, one row per sample
+
+        Returns:
+            a tensor of floats, one per sample
+        '''
+
+        dev = x.device
+        x.to_host()
+        y.to_host()
+
+        x_np = tensor.to_numpy(x)
+        y_np = tensor.to_numpy(y)
+
+        pred_np = np.argsort(-x_np)[:,0:self.top_k] #Sort in descending order
+
+        tmp_np = np.zeros(pred_np.shape, dtype=np.float32)
+
+        for i in range(pred_np.shape[0]):
+            tmp_np[i] = y_np[i,pred_np[i]]
+
+        prcs_np = np.average(tmp_np, axis=1)
+
+        prcs = tensor.from_numpy(prcs_np)
+
+        x.to_device(dev)
+        y.to_device(dev)
+        prcs.to_device(dev)
+
+        return prcs
+
+
+    def evaluate(self, x, y):
+        '''Compute the averaged precision over all samples.
+
+        Args:
+            x (Tensor): predictions, one row per sample
+            y (Tensor): ground truth values, one row per sample
+        Returns:
+            a float value for the averaged metric
+        '''
+
+        return tensor.average(self.forward(x,y))
+
+
+class Precision(Metric):
+    '''Make the top-k labels of max probability as the prediction
+
+    Compute the precision against the groundtruth labels
+    '''
+    def __init__(self, top_k):
+        self.top_k = top_k
+
+
+
+    def forward(self, x, y):
+        '''Compute the precision for each sample.
+
+        Convert tensor to numpy for computation
+
+        Args:
+            x (Tensor): predictions, one row per sample
+            y (Tensor): ground truth labels, one row per sample
+
+        Returns:
+            a tensor of floats, one per sample
+        '''
+
+        dev = x.device
+        x.to_host()
+        y.to_host()
+
+        x_np = tensor.to_numpy(x)
+        y_np = tensor.to_numpy(y)
+
+        pred_np = np.argsort(-x_np)[:,0:self.top_k] #Sort in descending order
+
+        tmp_np = np.zeros(pred_np.shape, dtype=np.float32)
+
+        for i in range(pred_np.shape[0]):
+            tmp_np[i] = y_np[i,pred_np[i]]
+
+        prcs_np = np.average(tmp_np, axis=1)
+
+        prcs = tensor.from_numpy(prcs_np)
+
+        x.to_device(dev)
+        y.to_device(dev)
+        prcs.to_device(dev)
+
+        return prcs
+
+
+    def evaluate(self, x, y):
+        '''Compute the averaged precision over all samples.
+
+        Args:
+            x (Tensor): predictions, one row per sample
+            y (Tensor): ground truth values, one row per sample
+        Returns:
+            a float value for the averaged metric
+        '''
+
+        return tensor.average(self.forward(x,y))

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dde8d14b/test/python/test_metric.py
----------------------------------------------------------------------
diff --git a/test/python/test_metric.py b/test/python/test_metric.py
new file mode 100644
index 0000000..0d298ae
--- /dev/null
+++ b/test/python/test_metric.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+import numpy as np
+
+from singa import metric
+from singa import tensor
+
+
+class TestPrecision(unittest.TestCase):
+    def setUp(self):
+        x_np = np.asarray([[0.7, 0.2, 0.1],
+                           [0.2, 0.4, 0.5],
+                           [0.2, 0.4, 0.4]],
+                          dtype=np.float32)
+
+        y_np = np.asarray([[1, 0, 1],
+                           [0, 1, 1],
+                           [1, 0, 0]],
+                          dtype=np.int32)
+
+        self.prcs = metric.Precision(top_k=2)
+        self.x = tensor.from_numpy(x_np)
+        self.y = tensor.from_numpy(y_np)
+
+
+    def test_forward(self):
+        p = self.prcs.forward(self.x, self.y)
+        self.assertAlmostEqual(tensor.to_numpy(p)[0], 0.5)
+        self.assertAlmostEqual(tensor.to_numpy(p)[1], 1)
+        self.assertAlmostEqual(tensor.to_numpy(p)[2], 0)
+
+
+    def test_evaluate(self):
+        e = self.prcs.evaluate(self.x, self.y)
+        self.assertAlmostEqual(e, (0.5 + 1 + 0) / 3)
+
+if __name__ == '__main__':
+    unittest.main()
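
The values asserted in test_forward (0.5, 1 and 0) follow directly from the
top-2 selection in Precision.forward: for each row, take the indices of the two
largest scores and average the ground-truth entries at those positions. A
minimal standalone NumPy sketch (not part of the patch), mirroring the loop in
the new forward(), reproduces them:

    import numpy as np

    # Scores and multi-hot ground truth copied from test_metric.py.
    x = np.asarray([[0.7, 0.2, 0.1],
                    [0.2, 0.4, 0.5],
                    [0.2, 0.4, 0.4]], dtype=np.float32)
    y = np.asarray([[1, 0, 1],
                    [0, 1, 1],
                    [1, 0, 0]], dtype=np.int32)

    top_k = 2
    pred = np.argsort(-x)[:, :top_k]       # indices of the k largest scores per row
    hits = np.zeros(pred.shape, dtype=np.float32)
    for i in range(pred.shape[0]):
        hits[i] = y[i, pred[i]]            # ground-truth values at the predicted positions
    per_sample = np.average(hits, axis=1)  # array([0.5, 1.0, 0.0], dtype=float32)
    print(per_sample, per_sample.mean())   # overall precision (0.5 + 1 + 0) / 3 = 0.5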

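For reference, the new metric would be driven from the Python API roughly as in
the added test; the sketch below assumes a SINGA build whose Python package
exposes singa.metric and singa.tensor:

    import numpy as np
    from singa import metric
    from singa import tensor

    scores = tensor.from_numpy(np.asarray([[0.7, 0.2, 0.1],
                                           [0.2, 0.4, 0.5],
                                           [0.2, 0.4, 0.4]], dtype=np.float32))
    labels = tensor.from_numpy(np.asarray([[1, 0, 1],
                                           [0, 1, 1],
                                           [1, 0, 0]], dtype=np.int32))

    prcs = metric.Precision(top_k=2)
    per_sample = prcs.forward(scores, labels)  # tensor with one precision value per sample
    overall = prcs.evaluate(scores, labels)    # averaged precision, 0.5 for this toy batch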