SINGA-245 Float as the first operand cannot be multiplied with a tensor object. Add reverse add/sub/mult/div for float-tensor operations. Add unit tests in test_tensor.py.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/0ebce1a4 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/0ebce1a4 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/0ebce1a4 Branch: refs/heads/master Commit: 0ebce1a44913dc760f3f6398b34fa45b3dcca5e8 Parents: 76cd806 Author: Wei Wang <[email protected]> Authored: Thu Sep 8 22:48:46 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Thu Sep 8 22:48:46 2016 +0800 ---------------------------------------------------------------------- src/python/singa/loss.py | 3 ++- src/python/singa/tensor.py | 24 ++++++++++++++++++++++++ test/python/test_tensor.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/src/python/singa/loss.py ---------------------------------------------------------------------- diff --git a/src/python/singa/loss.py b/src/python/singa/loss.py index 8b99ad3..526e4d0 100644 --- a/src/python/singa/loss.py +++ b/src/python/singa/loss.py @@ -95,6 +95,7 @@ class SoftmaxCrossEntropy(Loss): ''' def __init__(self): + super(SoftmaxCrossEntropy, self).__init__() self.swig_loss = singa.SoftmaxCrossEntropy() @@ -105,7 +106,7 @@ class SquaredError(Loss): It is implemented using Python Tensor operations. 
''' def __init__(self): - super(Loss, self).__init__() + super(SquareLoss, self).__init__() self.err = None def forward(self, flag, x, y): http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/src/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/src/python/singa/tensor.py b/src/python/singa/tensor.py index f6bca43..1024483 100644 --- a/src/python/singa/tensor.py +++ b/src/python/singa/tensor.py @@ -372,6 +372,7 @@ class Tensor(object): ''' python operators (+, -, *, /, <, <=, >, >=) for singa binary operators + https://docs.python.org/2/library/operator.html#mapping-operators-to-functions ''' def __add__(self, rhs): @@ -441,6 +442,29 @@ class Tensor(object): return _call_singa_func(singa.GE_Tf, self.singa_tensor, rhs) + def __radd__(self, lhs): + lhs = float(lhs) + return _call_singa_func(singa.Add_Tf, self.singa_tensor, lhs) + + def __rsub__(self, lhs): + lhs = float(lhs) + ret = _call_singa_func(singa.Sub_Tf, self.singa_tensor, lhs) + ret *= -1 + return ret + + def __rmul__(self, lhs): + lhs = float(lhs) + return _call_singa_func(singa.EltwiseMul_Tf, self.singa_tensor, lhs) + + def __rdiv__(self, lhs): + lhs = float(lhs) + one = Tensor(self.shape, self.device, self.dtype) + one.set_value(1) + one *= lhs + return _call_singa_func(singa.Div_TT, one.singa_tensor,\ + self.singa_tensor) + + ''' python functions for global functions in Tensor.h ''' http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/0ebce1a4/test/python/test_tensor.py ---------------------------------------------------------------------- diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py index 2374adc..a1f220b 100644 --- a/test/python/test_tensor.py +++ b/test/python/test_tensor.py @@ -133,5 +133,34 @@ class TestTensorMethods(unittest.TestCase): self.assertAlmostEqual(tensor.average(x), 1, 3) + def test_radd(self): + x = tensor.Tensor((3,)) + x.set_value(1) + y = 1 + x + 
self.assertEqual(tensor.average(y), 2.) + + + def test_rsub(self): + x = tensor.Tensor((3,)) + x.set_value(1) + y = 1 - x + self.assertEqual(tensor.average(y), 0.) + + + def test_rmul(self): + x = tensor.Tensor((3,)) + x.set_value(1) + y = 2 * x + self.assertEqual(tensor.average(y), 2.) + + + def test_rdiv(self): + x = tensor.Tensor((3,)) + x.set_value(1) + y = 2 / x + self.assertEqual(tensor.average(y), 2.) + + + if __name__ == '__main__': unittest.main()
