Repository: incubator-singa
Updated Branches:
  refs/heads/master f8cd7e384 -> eec0d52da
Add Mean Square Error loss operation and its unit test case

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/1ecafdec
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/1ecafdec
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/1ecafdec

Branch: refs/heads/master
Commit: 1ecafdececbacb13cc138a5d0f1b8745f6de2142
Parents: 8aac80e
Author: xuewanqi <[email protected]>
Authored: Wed Sep 5 04:49:37 2018 +0000
Committer: xuewanqi <[email protected]>
Committed: Wed Sep 5 04:49:37 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py      | 35 +++++++++++++++++++++++++++++++----
 test/python/test_operation.py | 16 ++++++++++++++++
 2 files changed, 47 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1ecafdec/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index b521126..11ad644 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -520,6 +520,30 @@ def softmax_cross_entropy(x, t):
     return SoftMaxCrossEntropy(t)(x)[0]
 
 
+class MeanSquareError(Operation):
+
+    def forward(self, x, t):
+        self.err = singa.__sub__(x, t)
+        sqr = singa.Square(self.err)
+        loss = CTensor((1,), x.device())
+        loss.SetFloatValue(singa.SumAsFloat(sqr) / x.shape()[0] / 2)
+        return loss
+
+    def backward(self, dy=1.0):
+        dx = self.err
+        dx *= float(1 / self.err.shape()[0])
+        if isinstance(dy, float):
+            # dtype of dy: float
+            dx *= dy
+            return dx, None
+        elif isinstance(dy, CTensor):
+            pass  # TODO, broadcast elementwise multiply seems not support
+
+
+def mean_square_error(x, t):
+    return MeanSquareError()(x, t)[0]
+
+
 def ctensor2numpy(x):
     '''
     To be used in SoftMax Operation.
@@ -1063,7 +1087,9 @@ def add_all(*xs):
         y = add(y, x)
     return
 
+
 class RNN(Layer):
+
     def __init__(self):
         raise NotImplementedError
 
@@ -1073,6 +1099,7 @@ class RNN(Layer):
     def step_forward(self):
         raise NotImplementedError
 
+
 class Vanilla_RNN(RNN):
 
     def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False):
@@ -1090,10 +1117,10 @@ class Vanilla_RNN(RNN):
         self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True)
         self.b.set_value(0.0)
 
-        self.params= (self.Wx, self.Wh, self.b)
+        self.params = (self.Wx, self.Wh, self.b)
 
     def __call__(self, h0, *xs):
-        inputs=xs+(h0,)
+        inputs = xs + (h0,)
         self.device_check(*inputs)
         #self.device_check(inputs[0], *self.params)
         self.device_check(inputs[0], self.Wx, self.Wh, self.b)
@@ -1154,10 +1181,10 @@ class LSTM(RNN):
             b.set_value(0.0)
             self.Bh.append(b)
 
-        self.params=self.Wx + self.Wh + self.Bx + self.Bh
+        self.params = self.Wx + self.Wh + self.Bx + self.Bh
 
     def __call__(self, h0, c0, *xs):
-        inputs=xs+(h0,c0)
+        inputs = xs + (h0, c0)
         self.device_check(*inputs)
         #self.device_check(inputs[0], *self.params)
         self.device_check(inputs[0], *(self.Wx + self.Wh + self.Bx + self.Bh))

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/1ecafdec/test/python/test_operation.py
----------------------------------------------------------------------
diff --git a/test/python/test_operation.py b/test/python/test_operation.py
index 2fdd9fb..9fb6a57 100755
--- a/test/python/test_operation.py
+++ b/test/python/test_operation.py
@@ -250,5 +250,21 @@ class TestPythonOperation(unittest.TestCase):
 
         self.gradients_check(lstm_forward, param, auto_grad)
 
+    def test_MeanSquareError(self):
+        X=np.array([4.3,5.4,3.3,3.6,5.7,6.0]).reshape(3,2).astype(np.float32)
+        T=np.array([4.4,5.3,3.2,3.7,5.4,6.3]).reshape(3,2).astype(np.float32)
+        x=tensor.from_numpy(X)
+        t=tensor.from_numpy(T)
+        x.to_device(gpu_dev)
+        t.to_device(gpu_dev)
+
+        loss= autograd.mean_square_error(x,t)
+        dx=loss.creator.backward()[0]
+
+        loss_np=tensor.to_numpy(loss)
+        self.assertAlmostEqual(loss_np, 0.0366666, places=4)
+        self.check_shape(dx.shape(), (3, 2))
+
+
 if __name__ == '__main__':
     unittest.main()
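
For reference, the value asserted in test_MeanSquareError above can be
reproduced with plain NumPy by following the formula the new forward()
implements: the sum of squared element-wise errors divided by the batch size
and then by 2, with backward() returning err / batch_size for the default
dy=1.0. The snippet below is an illustrative sketch only and is not part of
the commit:

    import numpy as np

    # Same inputs as the unit test above.
    X = np.array([4.3, 5.4, 3.3, 3.6, 5.7, 6.0]).reshape(3, 2).astype(np.float32)
    T = np.array([4.4, 5.3, 3.2, 3.7, 5.4, 6.3]).reshape(3, 2).astype(np.float32)

    err = X - T
    # forward(): SumAsFloat(Square(err)) / batch_size / 2
    loss = np.sum(err ** 2) / X.shape[0] / 2
    # backward() with dy=1.0: dx = err / batch_size
    dx = err / X.shape[0]

    print(loss)      # ~0.036666, matching assertAlmostEqual(..., 0.0366666, places=4)
    print(dx.shape)  # (3, 2), matching check_shape(dx.shape(), (3, 2))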
