Repository: incubator-singa
Updated Branches:
  refs/heads/master 65756e6f6 -> 2224d5f9a
SINGA-388 Develop some RNN layers by calling tiny operations like matmul, addbias

- Add unit test cases for both the vanilla RNN and the LSTM. The tests check the shapes of the gradients and also verify their values against numerical (finite-difference) results.
- Add device_check() to Vanilla_RNN and LSTM. This function checks the devices of the inputs and parameters; if they are not all on the same device, it transfers them onto one device.
- Fix some bugs in the test cases and source code.

Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/a44a01c0
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/a44a01c0
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/a44a01c0

Branch: refs/heads/master
Commit: a44a01c0f291cfca8a688570e3c752b1ef6ec829
Parents: 5dc17b9
Author: xuewanqi <[email protected]>
Authored: Tue Aug 14 13:37:07 2018 +0000
Committer: xuewanqi <[email protected]>
Committed: Thu Aug 16 11:40:25 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py      |  86 ++++++++++++-------
 python/singa/net.py           |   0
 python/singa/tensor.py        |   6 +-
 test/python/test_layer.py     |   2 +-
 test/python/test_operation.py | 172 ++++++++++++++++++++++++++++++++++++-
 5 files changed, 227 insertions(+), 39 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
old mode 100644
new mode 100755
index b18e08e..7032135
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -55,7 +55,8 @@ def infer_dependency(op): if src_op not in dependency_count: # dependency[src_op] = [Counter() for _ in src_op.y_id2idx] if isinstance(src_op, Dummy): - # only when a Dummy operator needs store grads, its dependency needs to be counted. + # only when a Dummy operator needs store grads, its + # dependency needs to be counted. if src_op.stores_grad: dependency_count[src_op] = 0 queue.append(src_op)
@@ -107,9 +108,9 @@ def backward(y, dy=None): if y.stores_grad: #gradients[y] = dy if isinstance(dy, float): - g=np.array(dy) + g = np.array(dy) else: - g=dy + g = dy tg = Tensor(device=g.device(), data=g) yield (y, tg)
@@ -139,7 +140,7 @@ def backward(y, dy=None): if isinstance(src_op, Dummy): if not src_op.stores_grad: continue - + y_idx = src_op.y_id2idx[x_id] if src_op not in not_ready: # src_op may have mulitple outputs
@@ -153,13 +154,15 @@ def backward(y, dy=None): # add the gradient from another children operation that # uses y_idx'th output of src_op as input arg dxs[y_idx] += dx - + dependency[src_op] -= 1 if y_stores_grad: if dependency[src_op] == 0: # store the gradient for final return, e.g. if x is parameter - # may cause a delay output, as only after src_op is ready then output, not the current outlet of src_op is ready then output. + # may cause a delay output, as only after src_op is ready + # then output, not the current outlet of src_op is ready + # then output. g = not_ready[src_op][y_idx] tg = Tensor(device=g.device(), data=g) yield (y, tg)
@@ -167,13 +170,13 @@ def backward(y, dy=None): if src_op.requires_grad is True: if dependency[src_op] == 0: if not isinstance(src_op, Dummy): - #Dummy can be in not_ready list but cannot be in ready list.
+ # Dummy can be in not_ready list but cannot be in ready + # list. ready.append((src_op, not_ready[src_op])) del not_ready[src_op] del op # delete the operation to free all tensors from this op - class Operation(object): ''' An operation includes the forward and backward function of @@ -800,7 +803,7 @@ class BatchNorm2d(Layer): self.handle.device_id = x.device.id() y = batchnorm_2d(self.handle, x, self.scale, self.bias, - self.running_mean, self.running_var) + self.running_mean, self.running_var) return y @@ -985,6 +988,7 @@ class Tanh(Operation): def tanh(x): return Tanh()(x)[0] + class Sigmoid(Operation): def forward(self, x): @@ -1021,31 +1025,28 @@ class ElemMatmul(Operation): def elemmatmul(x, y): return ElemMatmul()(x, y)[0] + def add_all(*xs): assert len(xs) > 2 - y=add(xs[0],xs[1]) + y = add(xs[0], xs[1]) for x in xs[2:]: - y=add(y, x) + y = add(y, x) return class RNN(Layer): def __init__(self): raise NotImplementedError - def __call__(self, h0, *xs): - batchsize=xs[0].shape[0] - out=[] - h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b) - out.append(h) - for x in xs[1:]: - assert x.shape[0] == batchsize - h = self.step_forward(x, h, self.Wx, self.Wh, self.b) - out.append(h) - return out, h + def __call__(self): + raise NotImplementedError + + def step_forward(self): + raise NotImplementedError class Vanilla_RNN(RNN): + def __init__(self, input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0, bidirectional=False): - self.nonlinearity=nonlinearity + self.nonlinearity = nonlinearity Wx_shape = (input_size, hidden_size) self.Wx = Tensor(shape=Wx_shape, requires_grad=True, stores_grad=True) @@ -1055,27 +1056,45 @@ class Vanilla_RNN(RNN): self.Wh = Tensor(shape=Wh_shape, requires_grad=True, stores_grad=True) self.Wh.gaussian(0.0, 1.0) - B_shape=(hidden_size,) + B_shape = (hidden_size,) self.b = Tensor(shape=B_shape, requires_grad=True, stores_grad=True) self.b.set_value(0.0) + #self.params= (self.Wx, self.Wh, self.b) + + def __call__(self, h0, *xs): + inputs=xs+(h0,) + self.device_check(*inputs) + #self.device_check(inputs[0], *self.params) + self.device_check(inputs[0], self.Wx, self.Wh, self.b) + batchsize = xs[0].shape[0] + out = [] + h = self.step_forward(xs[0], h0, self.Wx, self.Wh, self.b) + out.append(h) + for x in xs[1:]: + assert x.shape[0] == batchsize + h = self.step_forward(x, h, self.Wx, self.Wh, self.b) + out.append(h) + return out, h + def step_forward(self, x, h, Wx, Wh, b): - y1=matmul(x, Wx) - y2=matmul(h, Wh) - y=add(y1,y2) - y=add_bias(y,b,axis=0) + y1 = matmul(x, Wx) + y2 = matmul(h, Wh) + y = add(y1, y2) + y = add_bias(y, b, axis=0) if self.nonlinearity == 'tanh': - y=tanh(y) + y = tanh(y) elif self.nonlinearity == 'relu': - y=relu(y) + y = relu(y) else: raise ValueError return y + class LSTM(RNN): def __init__(self, input_size, hidden_size, nonlinearity='tanh', num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False): - self.nonlinearity=nonlinearity + self.nonlinearity = nonlinearity Wx_shape = (input_size, hidden_size) self.Wx = [] @@ -1105,7 +1124,13 @@ class LSTM(RNN): b.set_value(0.0) self.Bh.append(b) + #self.params=self.Wx + self.Wh + self.Bx + self.Bh + def __call__(self, h0, c0, *xs): + inputs=xs+(h0,c0) + self.device_check(*inputs) + #self.device_check(inputs[0], *self.params) + self.device_check(inputs[0], *(self.Wx + self.Wh + self.Bx + self.Bh)) batchsize = xs[0].shape[0] out = [] h, c = self.step_forward( @@ -1154,4 +1179,3 @@ class LSTM(RNN): hout = tanh(cout) hout = 
elemmatmul(o, hout) return hout, cout - http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/net.py ---------------------------------------------------------------------- diff --git a/python/singa/net.py b/python/singa/net.py old mode 100644 new mode 100755 http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/python/singa/tensor.py b/python/singa/tensor.py old mode 100644 new mode 100755 index 441431f..80c9a2e --- a/python/singa/tensor.py +++ b/python/singa/tensor.py @@ -638,7 +638,7 @@ def reshape(tensor, shape): Returns: the new Tensor ''' - return _call_singa_func(singa.Reshape, t.data, s) + return _call_singa_func(singa.Reshape, tensor.data, shape) def transpose(t, axes=None): @@ -1333,8 +1333,8 @@ def tensordot(A, B, axes=2): A = transpose(A, newaxes_a) B = transpose(B, newaxes_b) - at = Reshape(A, newshape_a) - bt = Reshape(B, newshape_b) + at = reshape(A, newshape_a) + bt = reshape(B, newshape_b) res = mult(at, bt) if len(olda + oldb) == 0: http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/test/python/test_layer.py ---------------------------------------------------------------------- diff --git a/test/python/test_layer.py b/test/python/test_layer.py old mode 100644 new mode 100755 index 2c49961..4c859f4 --- a/test/python/test_layer.py +++ b/test/python/test_layer.py @@ -62,7 +62,7 @@ class TestPythonLayer(unittest.TestCase): raw_x = np.arange(9, dtype=np.float32) + 1 x = tensor.from_numpy(raw_x) - x.reshape((1, 1, 3, 3)) + x = x.reshape((1, 1, 3, 3)) w = np.array([1, 1, 0, 0, 0, -1, 0, 1, 0], dtype=np.float32) params[0].copy_from_numpy(w) params[1].set_value(1.0) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/a44a01c0/test/python/test_operation.py ---------------------------------------------------------------------- diff --git a/test/python/test_operation.py b/test/python/test_operation.py index 67018c1..4975d99 100755 --- a/test/python/test_operation.py +++ b/test/python/test_operation.py @@ -6,6 +6,8 @@ from singa import singa_wrap as singa from singa import device from singa import autograd +import numpy as np + autograd.training = True CTensor = singa.Tensor @@ -21,6 +23,31 @@ def _tuple_to_string(t): lt = [str(x) for x in t] return '(' + ', '.join(lt) + ')' +def prepare_inputs_targets_for_rnn_test(): + x_0 = np.random.random((2, 3)).astype(np.float32) + x_1 = np.random.random((2, 3)).astype(np.float32) + x_2 = np.random.random((2, 3)).astype(np.float32) + + h_0 = np.random.random((2, 1)).astype( + np.float32) # (2,1) rather than (2,) + + t_0 = np.random.random((2, 2)).astype(np.float32) + t_1 = np.random.random((2, 2)).astype(np.float32) + t_2 = np.random.random((2, 2)).astype(np.float32) + + x0 = tensor.Tensor(device=gpu_dev, data=x_0) + x1 = tensor.Tensor(device=gpu_dev, data=x_1) + x2 = tensor.Tensor(device=gpu_dev, data=x_2) + + h0 = tensor.Tensor(device=gpu_dev, data=h_0) + + t0 = tensor.Tensor(device=gpu_dev, data=t_0) + t1 = tensor.Tensor(device=gpu_dev, data=t_1) + t2 = tensor.Tensor(device=gpu_dev, data=t_2) + + inputs = [x0, x1, x2] + targets = [t0, t1, t2] + return inputs, targets, h0 class TestPythonOperation(unittest.TestCase): @@ -32,8 +59,8 @@ class TestPythonOperation(unittest.TestCase): def test_conv2d_gpu(self): # (in_channels, out_channels, kernel_size) - conv_0 = autograd.Conv2D(3, 1, 2) - conv_without_bias_0 = autograd.Conv2D(3, 1, 2, bias=False) + conv_0 = 
autograd.Conv2d(3, 1, 2) + conv_without_bias_0 = autograd.Conv2d(3, 1, 2, bias=False) gpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=gpu_dev) gpu_input_tensor.gaussian(0.0, 1.0) @@ -52,8 +79,8 @@ class TestPythonOperation(unittest.TestCase): def test_conv2d_cpu(self): # (in_channels, out_channels, kernel_size) - conv_1 = autograd.Conv2D(3, 1, 2) - conv_without_bias_1 = autograd.Conv2D(3, 1, 2, bias=False) + conv_1 = autograd.Conv2d(3, 1, 2) + conv_without_bias_1 = autograd.Conv2d(3, 1, 2, bias=False) cpu_input_tensor = tensor.Tensor(shape=(2, 3, 3, 3), device=cpu_dev) cpu_input_tensor.gaussian(0.0, 1.0) @@ -87,5 +114,142 @@ class TestPythonOperation(unittest.TestCase): self.check_shape(ds.shape(), (3,)) self.check_shape(db.shape(), (3,)) + def test_vanillaRNN_gpu_tiny_ops(self): + # gradients shape check. + inputs, target, h0 = prepare_inputs_targets_for_rnn_test() + rnn = autograd.Vanilla_RNN(3, 2) + + hs, _ = rnn(h0, *inputs) + + loss = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss = autograd.add(loss, l) + # d=autograd.infer_dependency(loss.creator) + # print(d) + for t, dt in autograd.backward(loss): + self.check_shape(t.shape, dt.shape) + + def test_LSTM_gpu_tiny_ops(self): + # gradients shape check. + inputs, target, h0 = prepare_inputs_targets_for_rnn_test() + c_0 = np.random.random((2, 1)).astype(np.float32) + c0 = tensor.Tensor(device=gpu_dev, data=c_0) + + rnn = autograd.LSTM(3, 2) + + hs, _, _ = rnn(h0, c0, *inputs) + loss = autograd.softmax_cross_entropy(hs[0], target[0]) + + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss = autograd.add(loss, l) + # d=autograd.infer_dependency(loss.creator) + # print(d) + for t, dt in autograd.backward(loss): + self.check_shape(t.shape, dt.shape) + + def test_numerical_gradients_check_for_vallina_rnn(self): + inputs, target, h0 = prepare_inputs_targets_for_rnn_test() + + rnn = autograd.Vanilla_RNN(3, 2) + + hs, _ = rnn(h0, *inputs) + + loss1 = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss1 = autograd.add(loss1, l) + grads = autograd.gradients(loss1) + + # autograd gradients for dL/dWx[0][0] + d1 = tensor.to_numpy(grads[rnn.Wx])[0][0] + #print('autograd result of dL/dWx[0][0] is ', d1) + + + length = 0.01 + diff = np.array([1, 0, 0, 0, 0, 0]) * length + diff = np.reshape(diff, (3, 2)) + diff = tensor.from_numpy(diff) + diff.to_device(gpu_dev) + + rnn.Wx += diff + hs, _ = rnn(h0, *inputs) + #hs=rnn(h0, x0,x1) + loss2_p = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss2_p = autograd.add(loss2_p, l) + + rnn.Wx -= diff + rnn.Wx -= diff + hs, _ = rnn(h0, *inputs) + #hs=rnn(h0, x0,x1) + loss2_n = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss2_n = autograd.add(loss2_n, l) + + loss2_p_np = tensor.to_numpy(loss2_p) + loss2_n_np = tensor.to_numpy(loss2_n) + # Numerical gradients for dL/dWx[0][0] + d2 = (loss2_p_np - loss2_n_np) / 2 / length + #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length) + + self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3) + + def test_numerical_gradients_check_for_lstm(self): + inputs, target, h0 = prepare_inputs_targets_for_rnn_test() + c_0 = np.random.random((2, 
1)).astype(np.float32) + c0 = tensor.Tensor(device=gpu_dev, data=c_0) + + rnn = autograd.LSTM(3, 2) + + hs, _, _ = rnn(h0, c0, *inputs) + + loss1 = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss1 = autograd.add(loss1, l) + grads = autograd.gradients(loss1) + + # autograd gradients for dL/dWx[0][0] + d1 = tensor.to_numpy(grads[rnn.Wx[0]])[0][0] + #print('autograd result of dL/dWx[0][0] is ', d1) + + + length = 0.01 + diff = np.array([1, 0, 0, 0, 0, 0]) * length + diff = np.reshape(diff, (3, 2)) + diff = tensor.from_numpy(diff) + diff.to_device(gpu_dev) + + rnn.Wx[0] += diff + hs, _, _ = rnn(h0, c0, *inputs) + #hs=rnn(h0, x0,x1) + loss2_p = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss2_p = autograd.add(loss2_p, l) + + rnn.Wx[0] -= diff + rnn.Wx[0] -= diff + hs, _, _ = rnn(h0, c0, *inputs) + #hs=rnn(h0, x0,x1) + loss2_n = autograd.softmax_cross_entropy(hs[0], target[0]) + for i in range(1, len(hs)): + l = autograd.softmax_cross_entropy(hs[i], target[i]) + loss2_n = autograd.add(loss2_n, l) + + loss2_p_np = tensor.to_numpy(loss2_p) + loss2_n_np = tensor.to_numpy(loss2_n) + # Numerical gradients for dL/dWx[0][0] + d2 = (loss2_p_np - loss2_n_np) / 2 / length + #print('numerical calculation dL/dWx[0][0] is ', (loss2_p_np-loss2_n_np)/2/length) + + self.assertAlmostEqual(np.sum(d1 - d2), 0., places=3) + + + if __name__ == '__main__': unittest.main()
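
Note on the numerical-gradient tests above: they validate a single autograd gradient entry (dL/dWx[0][0]) against a central-difference estimate, perturbing the weight entry by +/- 0.01 and comparing (loss_plus - loss_minus) / (2 * 0.01), with agreement asserted to three decimal places. The snippet below is a minimal, self-contained sketch of that technique using plain NumPy and a toy quadratic loss in place of the SINGA graph; numerical_grad, loss, x and w are illustrative names, not part of the SINGA API.

    import numpy as np

    def numerical_grad(loss_fn, w, i, j, eps=0.01):
        # Central-difference estimate of d loss_fn(w) / d w[i, j].
        # w is perturbed in place and restored before returning.
        w[i, j] += eps
        loss_plus = loss_fn(w)
        w[i, j] -= 2 * eps
        loss_minus = loss_fn(w)
        w[i, j] += eps  # restore the original entry
        return (loss_plus - loss_minus) / (2 * eps)

    # Toy usage: loss(w) = sum((x @ w) ** 2) has the known gradient
    # 2 * x.T @ (x @ w), so analytic and numerical dL/dw[0][0] should agree.
    x = np.random.random((2, 3)).astype(np.float32)
    w = np.random.random((3, 2)).astype(np.float32)
    loss = lambda w_: float(np.sum((x @ w_) ** 2))

    analytic = (2 * x.T @ (x @ w))[0, 0]
    numeric = numerical_grad(loss, w, 0, 0)
    assert abs(analytic - numeric) < 1e-3

The step size trades off two error sources, as in the tests: it must be small enough that higher-order terms vanish, yet large enough that float32 rounding in the loss does not dominate the difference; the tests settle on 0.01 with places=3.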

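A note on device_check(): the diff calls self.device_check(...) from Vanilla_RNN.__call__ and LSTM.__call__, but the body of that helper is not visible in this excerpt. Purely as an illustration of the behaviour described in the commit message (check the devices of the inputs and parameters and move them onto one device), here is a sketch assuming each argument is a singa tensor.Tensor exposing .device (with .id()) and .to_device(), both of which appear elsewhere in this patch:

    def device_check(*inputs):
        # Treat the device of the first tensor as the reference device and
        # move any tensor that lives elsewhere onto it.
        ref_dev = inputs[0].device
        for t in inputs[1:]:
            if t.device.id() != ref_dev.id():
                t.to_device(ref_dev)

In the patch it is invoked as a layer method, first over the inputs and then over the first input together with the weight and bias tensors, so everything ends up on the device of the first input.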