SINGA-218 Implementation for RNN CUDNN version

Add an example using the char-rnn model. The trained model (with 2 stacks of LSTM) over the Linux kernel source code can generate source code with some meaningful patterns, e.g., indentation, comments, variable definitions, and assignments.
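For orientation, below is a minimal sketch of how the new Python RNN layer introduced in this commit is driven for a single inference step, following the pattern in examples/char-rnn/sample.py further down; the vocabulary size, hidden size, dropout value, and input values are illustrative assumptions, not values taken from the commit.

    # Minimal sketch (assumes SINGA is built with cuDNN v5 and its Python
    # package is importable); mirrors examples/char-rnn/sample.py from this commit.
    import numpy as np

    from singa import device
    from singa import layer
    from singa import tensor
    from singa.proto import model_pb2

    vocab_size, hidden_size, num_stacks = 101, 128, 2   # illustrative values

    cuda = device.create_cuda_gpu()
    rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, num_stacks=num_stacks,
                     dropout=0.5, input_sample_shape=(vocab_size,))
    rnn.to_device(cuda)

    # one-hot encoding of one character, plus zero-initialized hidden/cell states
    x = np.zeros((1, vocab_size), dtype=np.float32)
    x[0, 0] = 1
    tx = tensor.from_numpy(x)
    tx.to_device(cuda)
    hx = tensor.Tensor((num_stacks, 1, hidden_size), cuda)
    cx = tensor.Tensor((num_stacks, 1, hidden_size), cuda)
    hx.set_value(0.0)
    cx.set_value(0.0)

    # inputs follow <x1, ..., xn, hx, cx>; outputs follow <y1, ..., yn, hy, cy>
    outputs = rnn.forward(model_pb2.kEval, [tx, hx, cx])
    y, hy, cy = outputs[0], outputs[1], outputs[2]

The input/output convention matches the documentation added to src/model/layer/rnn.h in this commit; hx and cx may also be passed as empty (dummy) tensors, in which case zeros are used.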
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/dfc422e5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/dfc422e5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/dfc422e5 Branch: refs/heads/dev Commit: dfc422e5b6229de5f598ee3f0226f1a0d082eb16 Parents: 8e0b108 Author: Wei Wang <[email protected]> Authored: Mon Aug 8 17:41:26 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Wed Aug 10 00:47:10 2016 +0800 ---------------------------------------------------------------------- examples/char-rnn/README.md | 31 ++++++ examples/char-rnn/sample.py | 102 ++++++++++++++++++ examples/char-rnn/train.py | 207 +++++++++++++++++++++++++++++++++++++ src/core/tensor/tensor.cc | 3 +- src/io/csv_encoder.cc | 2 +- src/model/layer/cudnn_rnn.cc | 34 +++--- src/model/layer/rnn.cc | 16 +-- src/model/layer/rnn.h | 30 ++++-- src/proto/model.proto | 2 +- src/python/singa/layer.py | 71 ++++++++++++- src/python/swig/model_layer.i | 60 +++++++---- test/singa/test_cudnn_rnn.cc | 34 +++--- 12 files changed, 519 insertions(+), 73 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/README.md ---------------------------------------------------------------------- diff --git a/examples/char-rnn/README.md b/examples/char-rnn/README.md new file mode 100644 index 0000000..c5cdbb8 --- /dev/null +++ b/examples/char-rnn/README.md @@ -0,0 +1,31 @@ +# Train Char-RNN using SINGA + +Recurrent neural networks (RNN) are widely used for modelling sequential data, +e.g., natural language sentences. This example describe how to implement a RNN +application (or model) using SINGA's RNN layers. +We will use the [char-rnn](https://github.com/karpathy/char-rnn) modle as an +example, which trains over setences or +source code, with each character as an input unit. Particularly, we will train +a RNN using GRU over Linux kernel source code. After training, we expect to +generate meaningful code from the model. + + +## Instructions + +* Compile and install SINGA. Currently the RNN implmentation depends on Cudnn V5. + +* Prepare the dataset. Download the [kernel source code](http://cs.stanford.edu/people/karpathy/char-rnn/). +Other plain text files can also be used. + +* Start the training, + + python train.py input_linux.txt + + Some hyper-parameters could be set through command line, + + python train.py -h + + +* Sample characters from the model by providing num of characters and the seed string. + + python sample.py 100 --seed '#include <std' http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/sample.py ---------------------------------------------------------------------- diff --git a/examples/char-rnn/sample.py b/examples/char-rnn/sample.py new file mode 100644 index 0000000..a8fcb73 --- /dev/null +++ b/examples/char-rnn/sample.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +'''Sample characters from the pre-trained model''' +import sys +import os +import cPickle as pickle +import numpy as np +import argparse + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python')) +from singa import layer +from singa import tensor +from singa import device +from singa.proto import model_pb2 + + +def sample(model_path, nsamples=100, seed_text='', do_sample=True): + with open(model_path, 'rb') as fd: + d=pickle.load(fd) + rnn_w = tensor.from_numpy(d['rnn_w']) + idx_to_char=d['idx_to_char'] + char_to_idx=d['char_to_idx'] + vocab_size = len(idx_to_char) + dense_w = tensor.from_numpy(d['dense_w']) + dense_b = tensor.from_numpy(d['dense_b']) + hidden_size = d['hidden_size'] + num_stacks = d['num_stacks'] + dropout = d['dropout'] + + cuda = device.create_cuda_gpu() + rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, + num_stacks=num_stacks, dropout=dropout, + input_sample_shape=(len(idx_to_char),)) + rnn.to_device(cuda) + rnn.param_values()[0].copy_data(rnn_w) + dense = layer.Dense('dense', vocab_size, input_sample_shape=(hidden_size,)) + dense.to_device(cuda) + dense.param_values()[0].copy_data(dense_w) + dense.param_values()[1].copy_data(dense_b) + hx = tensor.Tensor((num_stacks, 1, hidden_size), cuda) + cx = tensor.Tensor((num_stacks, 1, hidden_size), cuda) + hx.set_value(0.0) + cx.set_value(0.0) + if len(seed_text) > 0: + for c in seed_text: + x = np.zeros((1, vocab_size), dtype=np.float32) + x[0, char_to_idx[c]] = 1 + tx=tensor.from_numpy(x) + tx.to_device(cuda) + inputs=[tx, hx, cx] + outputs=rnn.forward(model_pb2.kEval, inputs) + y = dense.forward(model_pb2.kEval, outputs[0]) + y = tensor.softmax(y) + hx = outputs[1] + cx = outputs[2] + sys.stdout.write(seed_text) + else: + y = tensor.Tensor((1, vocab_size), cuda) + y.set_value(1.0 / vocab_size) + + for i in range(nsamples): + y.to_host() + prob = tensor.to_numpy(y)[0] + if do_sample: + cur=np.random.choice(vocab_size, 1, p=prob)[0] + else: + cur = np.argmax(prob) + sys.stdout.write(idx_to_char[cur]) + x = np.zeros((1, vocab_size), dtype=np.float32) + x[0, cur] = 1 + tx=tensor.from_numpy(x) + tx.to_device(cuda) + inputs=[tx, hx, cx] + outputs=rnn.forward(model_pb2.kEval, inputs) + y = dense.forward(model_pb2.kEval, outputs[0]) + y = tensor.softmax(y) + hx = outputs[1] + cx = outputs[2] + print '' + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='sample chars from char-rnn') + parser.add_argument('--seed', help='seed text string which warms up the rnn'\ + ' states for sampling', default='') + parser.add_argument('n', type=int, help='num of characters to sample') + args = parser.parse_args() + assert args.n > 0, 'n must > 0' + sample('model.bin', args.n, seed_text=args.seed) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/examples/char-rnn/train.py ---------------------------------------------------------------------- diff --git a/examples/char-rnn/train.py b/examples/char-rnn/train.py new file mode 100644 index 0000000..22fdc82 --- /dev/null +++ 
b/examples/char-rnn/train.py @@ -0,0 +1,207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +'''Train a Char-RNN model using plain text files. +The model is created following https://github.com/karpathy/char-rnn +The train file could be any text file, +e.g., http://cs.stanford.edu/people/karpathy/char-rnn/ +''' +import sys +import os +import cPickle as pickle +import numpy as np +import argparse + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../build/python')) +from singa import layer +from singa import loss +from singa import device +from singa import tensor +from singa import optimizer +from singa import initializer +from singa.proto import core_pb2 +from singa.proto import model_pb2 +from singa import utils + + +class Data(object): + def __init__(self, fpath, batch_size=32, seq_length=100, train_ratio=0.8): + '''Data object for loading a plain text file. + + Args: + fpath, path to the text file. + train_ratio, split the text file into train and test sets, where + train_ratio of the characters are in the train set. 
+ ''' + self.raw_data = open(fpath, 'r').read() # read text file + chars = list(set(self.raw_data)) + self.vocab_size = len(chars) + self.char_to_idx = {ch:i for i, ch in enumerate(chars)} + self.idx_to_char = {i:ch for i, ch in enumerate(chars)} + data = [self.char_to_idx[c] for c in self.raw_data] + # seq_length + 1 for the data + label + nsamples = len(data) / (1 + seq_length) + data = data[0:nsamples * (1 + seq_length)] + data = np.asarray(data, dtype=np.int32) + data = np.reshape(data, (-1, seq_length + 1)) + # shuffle all sequences + np.random.shuffle(data) + self.train_dat = data[0:int(data.shape[0]*train_ratio)] + self.num_train_batch = self.train_dat.shape[0] / batch_size + self.val_dat = data[self.train_dat.shape[0]:] + self.num_test_batch = self.val_dat.shape[0] / batch_size + print 'train dat', self.train_dat.shape + print 'val dat', self.val_dat.shape + + +def numpy2tensors(npx, npy, dev): + '''batch, seq, dim -- > seq, batch, dim''' + tmpx=np.swapaxes(npx, 0, 1) + tmpy=np.swapaxes(npy, 0, 1) + inputs=[] + labels=[] + for t in range(tmpx.shape[0]): + x = tensor.from_numpy(tmpx[t]) + y = tensor.from_numpy(tmpy[t]) + x.to_device(dev) + y.to_device(dev) + inputs.append(x) + labels.append(y) + return inputs, labels + + +def convert(batch, batch_size, seq_length, vocab_size, dev): + '''convert a batch of data into a sequence of input tensors''' + y = batch[:, 1:] + x1 = batch[:, :seq_length] + x = np.zeros((batch_size, seq_length, vocab_size), dtype=np.float32) + for b in range(batch_size): + for t in range(seq_length): + c = x1[b, t] + x[b, t, c] = 1 + return numpy2tensors(x, y, dev) + + +def get_lr(epoch): + return 0.001 / float(1 << (epoch / 50)) + + +def train(data, max_epoch, hidden_size =100, seq_length=100, batch_size=16, + num_stacks=1, lr=0.001, dropout = 0.5, model_path='model.bin'): + # SGD with L2 gradient normalization + opt = optimizer.SGD(constraint=optimizer.L2Constraint(5)) + cuda = device.create_cuda_gpu() + rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, num_stacks=num_stacks, + dropout=dropout, input_sample_shape=(data.vocab_size,)) + rnn.to_device(cuda) + print 'created rnn' + rnn_w = rnn.param_values()[0] + initializer.uniform(rnn_w, -0.08, 0.08) # init all rnn parameters + print 'rnn weight l1 = %f' % (rnn_w.l1()) + dense = layer.Dense('dense', data.vocab_size, input_sample_shape=(hidden_size,)) + dense.to_device(cuda) + dense_w = dense.param_values()[0] + dense_b = dense.param_values()[1] + print 'dense w ', dense_w.shape + print 'dense b ', dense_b.shape + initializer.xavier(dense_w) # init weight matrix using Xavier + print 'dense weight l1 = %f' % (dense_w.l1()) + dense_b.set_value(0.0) + print 'dense b l1 = %f' % (dense_b.l1()) + + g_dense_w = tensor.Tensor(dense_w.shape, cuda) + g_dense_b = tensor.Tensor(dense_b.shape, cuda) + + lossfun = loss.SoftmaxCrossEntropy(); + for epoch in range(max_epoch): + train_loss = 0 + for b in range(data.num_train_batch): + batch = data.train_dat[b * batch_size: (b + 1) * batch_size] + inputs, labels = convert(batch, batch_size, seq_length, + data.vocab_size, cuda) + inputs.append(tensor.Tensor()) + inputs.append(tensor.Tensor()) + + outputs = rnn.forward(model_pb2.kTrain, inputs)[0:-2] + grads=[] + batch_loss = 0 + g_dense_w.set_value(0.0) + g_dense_b.set_value(0.0) + for output, label in zip(outputs, labels): + act = dense.forward(model_pb2.kTrain, output) + lvalue = lossfun.forward(model_pb2.kTrain, act, label) + batch_loss += lvalue.l1() + grad = lossfun.backward() + grad, gwb = 
dense.backward(model_pb2.kTrain, grad) + grads.append(grad) + g_dense_w += gwb[0] + g_dense_b += gwb[1] + #print output.l1(), act.l1() + utils.update_progress(b * 1.0 / data.num_train_batch, + 'training loss = %f' % (batch_loss / seq_length)) + train_loss += batch_loss + + grads.append(tensor.Tensor()) + grads.append(tensor.Tensor()) + g_rnn_w=rnn.backward(model_pb2.kTrain, grads)[1][0] + dense_w, dense_b = dense.param_values() + opt.apply_with_lr(epoch, get_lr(epoch), g_rnn_w, rnn_w, 'rnnw') + opt.apply_with_lr(epoch, get_lr(epoch), g_dense_w, dense_w, 'dense_w') + opt.apply_with_lr(epoch, get_lr(epoch), g_dense_b, dense_b, 'dense_b') + print '\nEpoch %d, train loss is %f' % (epoch, + train_loss / data.num_train_batch / seq_length) + eval_loss = 0 + for b in range(data.num_test_batch): + batch = data.val_dat[b * batch_size: (b + 1) * batch_size] + inputs, labels = convert(batch, batch_size, seq_length, + data.vocab_size, cuda) + inputs.append(tensor.Tensor()) + inputs.append(tensor.Tensor()) + outputs = rnn.forward(model_pb2.kEval, inputs)[0:-2] + for output, label in zip(outputs, labels): + output = dense.forward(model_pb2.kEval, output) + eval_loss += lossfun.forward(model_pb2.kEval, output, label).l1() + print 'Epoch %d, evaluation loss is %f' % (epoch, + eval_loss / data.num_test_batch / seq_length) + + # checkpoint the file model + with open(model_path, 'wb') as fd: + print 'saving model to %s' % model_path + d={} + for name, w in zip(['rnn_w', 'dense_w', 'dense_b'], [rnn_w, dense_w, dense_b]): + w.to_host() + d[name]=tensor.to_numpy(w) + d['idx_to_char']=data.idx_to_char + d['char_to_idx']=data.char_to_idx + d['hidden_size']=hidden_size + d['num_stacks']=num_stacks + d['dropout']=dropout + + pickle.dump(d, fd) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Train multi-stack LSTM for '\ + 'modeling character sequence from plain text files') + parser.add_argument('data', type=string, help='training file') + parser.add_argument('-b', type=int, default=32, help='batch_size') + parser.add_argument('-l', type=int, default=64, help='sequence length') + parser.add_argument('-d', type=int, default=128, help='hidden size') + parser.add_argument('-s', type=int, default=2, help='num of stacks') + parser.add_argument('-m', type=int, default=50, help='max num of epoch') + args = parser.parse_args() + data = Data(args.data, batch_size=args.b, seq_length=args.l) + train(data, args.m, hidden_size=args.d, num_stacks=args.s, + seq_length=args.l, batch_size=args.b) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index bd3bc70..d2fec53 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -299,7 +299,8 @@ Tensor &Tensor::operator=(const Tensor &in) { shape_ = in.shape_; device_ = in.device_; block_ = in.block(); - block_->IncRefCount(); + if (block_ != nullptr) + block_->IncRefCount(); return *this; } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/io/csv_encoder.cc ---------------------------------------------------------------------- diff --git a/src/io/csv_encoder.cc b/src/io/csv_encoder.cc index 1b797a9..6089ab5 100644 --- a/src/io/csv_encoder.cc +++ b/src/io/csv_encoder.cc @@ -22,7 +22,7 @@ namespace singa { std::string CSVEncoder::Encode(vector<Tensor>& data) { - CHECK_GE(data.size(), 1); + CHECK_GE(data.size(), 1u); size_t size = 
data[0].Size(); const float* value = data[0].data<float>(); std::string des = ""; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/cudnn_rnn.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc index 242a342..896c1e9 100644 --- a/src/model/layer/cudnn_rnn.cc +++ b/src/model/layer/cudnn_rnn.cc @@ -17,6 +17,7 @@ */ #include "./cudnn_rnn.h" #ifdef USE_CUDNN +#if CUDNN_VERSION_MAJOR >= 5 #include <cudnn.h> #include <chrono> #include "./cudnn_utils.h" @@ -92,7 +93,7 @@ void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) { } for (size_t i = 0; i < len; i++) { - CHECK_EQ(inputs[i].shape(1), input_dim_); + CHECK_EQ(inputs[i].shape(1), input_size_); if (inputs[i].shape(0) != batch_size_ || reset) { int d[3] = {1, 1, 1}, s[3] = {1, 1, 1}; d[0] = static_cast<int>(inputs[i].shape(0)); @@ -104,7 +105,7 @@ void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) { CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s)); d[0] = static_cast<int>(inputs[i].shape(0)); - d[1] = static_cast<int>(hidden_dim_ * num_directions_); + d[1] = static_cast<int>(hidden_size_ * num_directions_); s[0] = d[1] * d[2]; s[1] = d[2]; CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s)); @@ -121,7 +122,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) { CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size)); dropout_state_ = Tensor(Shape{state_size}, dev, kChar); CUDNN_CHECK(cudnnSetDropoutDescriptor( - dropout_desc_, ctx->cudnn_handle, dropout_, + dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability dropout_state_.block()->mutable_data(), state_size, seed_)); CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); @@ -146,7 +147,7 @@ void CudnnRNN::SetRNNDescriptor(shared_ptr<Device> dev) { rnn_mode = CUDNN_LSTM; else if (rnn_mode_ == "gru") rnn_mode = CUDNN_GRU; - CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_dim_, num_stacks_, + CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_, dropout_desc_, input_mode, direction, rnn_mode, dtype_)); @@ -176,7 +177,7 @@ void CudnnRNN::ResetHiddenAndCellDescriptors(size_t batch_size) { int dim[3] = {1, 1, 1}; dim[0] = static_cast<int>(num_stacks_ * num_directions_); dim[1] = static_cast<int>(batch_size); - dim[2] = static_cast<int>(hidden_dim_); + dim[2] = static_cast<int>(hidden_size_); int stride[3] = {1, 1, 1}; stride[0] = dim[1] * dim[2]; stride[1] = dim[2]; @@ -238,7 +239,7 @@ vector<Tensor> CudnnRNN::SplitOutput(size_t num, size_t dim, const Tensor output) { vector<Tensor> outputs; if (num == 1) { - outputs.push_back(output); + outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim})); } else { for (size_t i = 0, offset = 0; offset < output.Size(); i++) { Shape s{in.at(i).shape(0), dim}; @@ -261,7 +262,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { CHECK_GT(inputs.size(), 1u + has_cell_); size_t num_x = inputs.size() - has_cell_ - 1; Tensor input = MergeInputs(num_x, inputs); - LOG(INFO) << "input size " << input.Size() << " value " << input.L1(); + // LOG(INFO) << "input size " << input.Size() << " value " << input.L1(); if (rnn_desc_ != nullptr) CHECK_EQ(dtype_, GetCudnnDataType(dtype)) @@ -273,11 +274,11 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { UpdateStates(num_x, inputs); // CheckFowardShapes(); - Shape outshape{input.Size() 
* hidden_dim_ / input_dim_ * num_directions_}; + Shape outshape{input.Size() * hidden_size_ / input_size_ * num_directions_}; Tensor output(outshape, dev, dtype); - LOG(INFO) << "output size " << output.Size(); + // LOG(INFO) << "output size " << output.Size(); Tensor hx = inputs.at(num_x); - Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_}; + Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_}; Tensor hy(state_shape, dev, dtype); Tensor cy, cx; if (has_cell_) { @@ -285,8 +286,8 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { cy.ResetLike(hy); } - LOG(INFO) << "hidden size " << hy.Size(); - LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1(); + // LOG(INFO) << "hidden size " << hy.Size(); + // LOG(INFO) << "weight size " << weight_.Size() << " value " << weight_.L1(); Block *inb = input.block(), *outb = output.block(), *wb = this->weight_.block(), *hxb = hx.block(), *cxb = cx.block(), *hyb = hy.block(), *cyb = cy.block(), @@ -336,7 +337,7 @@ const vector<Tensor> CudnnRNN::Forward(int flag, const vector<Tensor> &inputs) { }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); } auto outputs = - SplitOutput(num_x, hidden_dim_ * num_directions_, inputs, output); + SplitOutput(num_x, hidden_size_ * num_directions_, inputs, output); outputs.push_back(hy); if (has_cell_) outputs.push_back(cy); return outputs; @@ -368,10 +369,10 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( if (has_cell_) dcy = grads.at(num_dy + 1); - Shape xshape{y.Size() * input_dim_ / hidden_dim_ / num_directions_}; + Shape xshape{y.Size() * input_size_ / hidden_size_ / num_directions_}; Tensor dx(xshape, dev, dtype); Tensor dw(weight_.shape(), dev, dtype); - Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_dim_}; + Shape state_shape{num_stacks_ * num_directions_, batch_size_, hidden_size_}; Tensor dhx(state_shape, dev, dtype); Tensor dcx; if (has_cell_) @@ -419,7 +420,7 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( {dxb, dwb, dhxb, dcxb, wspace, rspace}); vector <Tensor> param_grad{dw}; - auto data_grads = SplitOutput(num_dy, input_dim_, grads, dx); + auto data_grads = SplitOutput(num_dy, input_size_, grads, dx); data_grads.push_back(dhx); if (has_cell_) data_grads.push_back(dcx); @@ -427,4 +428,5 @@ const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward( } } // namespace singa +#endif // CUDNN_VERSION_MAJOR >= 5 #endif // USE_CUDNN http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/rnn.cc ---------------------------------------------------------------------- diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc index 6b831a7..424c20b 100644 --- a/src/model/layer/rnn.cc +++ b/src/model/layer/rnn.cc @@ -27,13 +27,13 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) { Layer::Setup(in_sample, conf); RNNConf rnn_conf = conf.rnn_conf(); - hidden_dim_ = rnn_conf.hidden_dim(); - CHECK_GT(hidden_dim_, 0u); + hidden_size_ = rnn_conf.hidden_size(); + CHECK_GT(hidden_size_, 0u); num_stacks_ = rnn_conf.num_stacks(); CHECK_GT(num_stacks_, 0u); - input_dim_ = Product(in_sample); - CHECK_GT(input_dim_, 0u); - dropout_ = rnn_conf.dropout(); + input_size_ = Product(in_sample); + CHECK_GT(input_size_, 0u); + dropout_ = rnn_conf.dropout(); // drop probability CHECK_GE(dropout_, 0); input_mode_ = ToLowerCase(rnn_conf.input_mode()); @@ -71,9 +71,9 @@ void RNN::Setup(const Shape& in_sample, const LayerConf 
&conf) { size_t weight_size = 0; for (size_t i = 0; i < num_stacks_; i++) { - size_t dim = hidden_dim_ * (in_sample[0] + hidden_dim_ + 2); + size_t dim = hidden_size_ * (in_sample[0] + hidden_size_ + 2); if (i > 0) - dim = hidden_dim_ * (hidden_dim_ + hidden_dim_ + 2); + dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2); weight_size += mult * dim; } weight_.Reshape(Shape{weight_size}); @@ -81,6 +81,7 @@ void RNN::Setup(const Shape& in_sample, const LayerConf &conf) { const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) { vector<Tensor> data_output; + LOG(FATAL) << "CPU RNN is not implemented!"; return data_output; } @@ -88,6 +89,7 @@ const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(int flag, const vector<Tensor>& grads) { vector<Tensor> param_grad; vector<Tensor> data_grad; + LOG(FATAL) << "CPU RNN is not implemented!"; return std::make_pair(data_grad, param_grad); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/model/layer/rnn.h ---------------------------------------------------------------------- diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h index 3750021..1b5dad7 100644 --- a/src/model/layer/rnn.h +++ b/src/model/layer/rnn.h @@ -37,20 +37,32 @@ class RNN : public Layer { /// \copydoc Layer::layer_type() const std::string layer_type() const override { return "RNN"; } - /// \copydoc Layer::Setup(const LayerConf&); - void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override; + /// Setup the RNN layer. + /// in_shape is the shape of a single training instance from one timestep, + void Setup(const Shape& in_shape, const LayerConf& conf) override; - /// \copydoc Layer::Forward(int flag, const vector<Tensor>&) + /// The inputs vector includes <x1, ... xn, hx, cx> where xi is the input + /// tensor at the i-th time step. hx is used to initialize the hidden tensor, + /// which could be a dummy tensor (like Tensor hx;). cx is used to initialize + /// the cell tensor, which could be a dummy tensor( like Tensor cx;). For + /// dummy tensors, 0's would be used during computation. + /// cx is missing for gru/relu/tanh RNNs, and is valid for lstm. + /// The dim order of xi is <batch, feature>, and the batchsize of xi must be + /// >= that of x(i+1). + /// The output vector includes <y1, ... yn, hy, cy> where yi is the output + /// tensor at the i-th time step. hy is the final hidden tensor, cy is the + /// final cell tensor. cy is missing for gru/relu/tanh RNNs and is valid for + /// lstm. const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) override; - /// \copydoc Layer::Backward(int, const vector<Tensor>&); + /// The grads vector includes <dy1, dy2, ... dyn, dhy, dcy>, the symbols are + /// similar to those for Forward. dcy is missing for gru/relu/tanh RNNs and is + /// valid for lstm. + /// The first vector of the output includes <dx1, dx2, ... dxn, dhx, dcx>. + /// The second vector of the output includes the gradients of all parameters. 
const std::pair<vector<Tensor>, vector<Tensor>> Backward( int flag, const vector<Tensor>& grads) override; - void set_weight(Tensor w) { - weight_.ResetLike(w); - weight_.CopyData(w); - } const vector<Tensor> param_values() override { return vector<Tensor>{weight_}; } @@ -73,7 +85,7 @@ class RNN : public Layer { std::stack<Tensor> buf_; bool has_cell_ = false; size_t num_directions_ = 1; - size_t input_dim_ = 0, hidden_dim_ = 0, num_stacks_ = 0, seq_length_ = 0; + size_t input_size_ = 0, hidden_size_ = 0, num_stacks_ = 0, seq_length_ = 0; size_t batch_size_ = 0; size_t seed_ = 0x1234567; float dropout_ = 0.0f; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/proto/model.proto ---------------------------------------------------------------------- diff --git a/src/proto/model.proto b/src/proto/model.proto index 31ebfc3..6923820 100644 --- a/src/proto/model.proto +++ b/src/proto/model.proto @@ -393,7 +393,7 @@ message ConvolutionConf { } message RNNConf { - optional uint32 hidden_dim = 1; // The number of hiddensize + optional uint32 hidden_size = 1; // The hidden feature size optional uint32 num_stacks = 2; // The number of stacked RNN layers optional float dropout = 3 [default = 0]; optional bool remember_state = 4 [default = false]; http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/python/singa/layer.py ---------------------------------------------------------------------- diff --git a/src/python/singa/layer.py b/src/python/singa/layer.py index a443e1a..a87eb10 100644 --- a/src/python/singa/layer.py +++ b/src/python/singa/layer.py @@ -403,7 +403,7 @@ class Dense(Layer): if W_specs is None: W_specs = {'init': 'xavier'} if b_specs is None: - b_specs = {'init': 'constant'} + b_specs = {'init': 'constant', 'value': 0} if 'name' not in W_specs: W_specs['name'] = name + '_weight' if 'name' not in b_specs: @@ -502,6 +502,71 @@ class Flatten(Layer): self.setup(input_sample_shape) +class RNN(Layer): + def __init__(self, name, hidden_size, rnn_mode='lstm', engine='cudnn', + dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, + param_specs=None, input_sample_shape=None): + super(RNN, self).__init__(name) + conf = self.conf.rnn_conf + assert hidden_size > 0, 'Hidden feature size must > 0' + conf.hidden_size = hidden_size + assert rnn_mode in Set(['lstm', 'gru', 'tanh', 'relu']), \ + 'rnn mode %s is not available' %s (rnn_mode) + conf.rnn_mode = rnn_mode + conf.num_stacks = num_stacks + conf.dropout = dropout + conf.input_mode = input_mode + conf.direction = 'unidirectional' + if bidirectional: + conf.direction = 'bidirectional' + _check_engine(engine, ['cudnn']) + if param_specs is None: + param_specs = {'name': name + '-weight', + 'init': 'uniform', 'low':0, 'high':1}; + self.conf.param.extend([_construct_param_specs_from_dict(param_specs)]) + self.param_specs.append(_construct_param_specs_from_dict(param_specs)) + + self.layer = singa_wrap.CudnnRNN() + if input_sample_shape is not None: + self.setup(input_sample_shape) + + def forward(self, flag, inputs): + assert self.has_setup, 'Must call setup() before forward()' + assert len(inputs) > 1, 'The input to RNN must include at '\ + 'least one input tensor '\ + 'and one hidden state tensor (could be a dummy tensor)' + tensors = [] + for t in inputs: + assert isinstance(t, tensor.Tensor), 'input must be py Tensor %s' % (type(t)) + tensors.append(t.singa_tensor) + y = self.layer.Forward(flag, tensors) + return tensor.from_raw_tensors(y) + + def backward(self, flag, grad): + tensors = [] + 
for t in grad: + assert isinstance(t, tensor.Tensor), 'grad must be py Tensor' + tensors.append(t.singa_tensor) + ret = self.layer.Backward(flag, tensors) + return tensor.from_raw_tensors(ret[0]), tensor.from_raw_tensors(ret[1]) + +class LSTM(RNN): + def __init__(self, name, hidden_size, engine='cudnn', + dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, + param_specs=None, input_sample_shape=None): + super(LSTM, self).__init__(name, hidden_size, 'lstm', engine, dropout, + num_stacks, input_mode, bidirectional, param_specs, + input_sample_shape) + +class GRU(RNN): + def __init__(self, name, hidden_size, engine='cudnn', + dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, + param_specs=None, input_sample_shape=None): + super(GRU, self).__init__(name, hidden_size, 'gru', engine, dropout, + num_stacks, input_mode, bidirectional, param_specs, + input_sample_shape) + + def _check_engine(engine, allowed_engines): assert engine.lower() in Set(allowed_engines), \ '%s is not a supported engine. Pls use one of %s' % \ @@ -585,8 +650,8 @@ def _construct_param_specs_from_dict(specs): if specs['init'].lower() == 'uniform': assert 'low' in specs and 'high' in specs, \ 'low and high are required for "uniform" init method' - filler.low = specs['low'] - filler.high = specs['high'] + filler.min = specs['low'] + filler.max = specs['high'] elif specs['init'].lower() == 'gaussian': assert 'mean' in specs and 'std' in specs, \ 'std and mean are required for "gaussian" init method' http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/src/python/swig/model_layer.i ---------------------------------------------------------------------- diff --git a/src/python/swig/model_layer.i b/src/python/swig/model_layer.i index 873ebc9..9d39301 100644 --- a/src/python/swig/model_layer.i +++ b/src/python/swig/model_layer.i @@ -30,6 +30,8 @@ %{ #include "singa/model/layer.h" +#include "../src/model/layer/rnn.h" +#include "../src/model/layer/cudnn_rnn.h" #include "singa/core/tensor.h" #include "singa/proto/model.pb.h" using singa::Tensor; @@ -40,6 +42,8 @@ using singa::LayerConf; %} %shared_ptr(singa::Layer) +%shared_ptr(singa::RNN) +%shared_ptr(singa::CudnnRNN) namespace std { %template(strVector) vector<string>; @@ -52,26 +56,44 @@ namespace std { namespace singa { - class Layer { - public: - Layer(); +class Layer { + public: + Layer(); // virtual void Setup(const std::vector<vector<size_t>>&, const string&); - virtual void Setup(const std::vector<size_t>& in_sample_shape, - const std::string& proto_str); - const std::vector<Tensor> param_values(); - virtual const std::vector<size_t> GetOutputSampleShape() const; - virtual void ToDevice(std::shared_ptr<Device> device); - virtual void AsType(DataType dtype); - virtual const Tensor Forward(int flag, const Tensor& input); - virtual const std::vector<Tensor> Forward( - int flag, const std::vector<Tensor>& inputs); - virtual const std::pair<Tensor, std::vector<Tensor>> Backward( - int flag, const Tensor& grad); - virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>> - Backward(int flag, const vector<Tensor>& grads); + void Setup(const std::vector<size_t>& in_sample_shape, + const std::string& proto_str); + virtual const std::vector<Tensor> param_values(); + virtual const std::vector<size_t> GetOutputSampleShape() const; + virtual void ToDevice(std::shared_ptr<Device> device); + virtual void AsType(DataType dtype); + virtual const Tensor Forward(int flag, const Tensor& input); + virtual const std::vector<Tensor> Forward( 
+ int flag, const std::vector<Tensor>& inputs); + virtual const std::pair<Tensor, std::vector<Tensor>> Backward( + int flag, const Tensor& grad); + virtual const std::pair<std::vector<Tensor>, std::vector<Tensor>> + Backward(int flag, const vector<Tensor>& grads); +}; + +std::shared_ptr<Layer> CreateLayer(const std::string& type); +const std::vector<std::string> GetRegisteredLayers(); +class RNN : public Layer { + /* + public: + void Setup(const std::vector<size_t>& in_sample_shape, + const std::string& proto_str) override; + */ +}; +class CudnnRNN : public RNN { + public: + // note: Must use std::vector instead of vector. + const std::vector<Tensor> Forward(int flag, const std::vector<Tensor>& inputs) override; + const std::pair<std::vector<Tensor>, std::vector<Tensor>> Backward( + int flag, const std::vector<Tensor>& grads) override; + void ToDevice(std::shared_ptr<Device> device) override; + const std::vector<Tensor> param_values() override; + const std::vector<size_t> GetOutputSampleShape() const override; +}; - }; - std::shared_ptr<Layer> CreateLayer(const std::string& type); - const std::vector<std::string> GetRegisteredLayers(); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/dfc422e5/test/singa/test_cudnn_rnn.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc index ebbf0aa..e0de02e 100644 --- a/test/singa/test_cudnn_rnn.cc +++ b/test/singa/test_cudnn_rnn.cc @@ -21,6 +21,7 @@ #include "../src/model/layer/cudnn_rnn.h" #ifdef USE_CUDNN +#if CUDNN_VERSION_MAJOR >= 5 #include "gtest/gtest.h" @@ -31,15 +32,15 @@ class TestCudnnRNN : public ::testing::Test { protected: virtual void SetUp() { singa::RNNConf *rnnconf = conf.mutable_rnn_conf(); - rnnconf->set_hidden_dim(hidden_dim); + rnnconf->set_hidden_size(hidden_size); rnnconf->set_num_stacks(1); - rnnconf->set_dropout(1); + rnnconf->set_dropout(0); rnnconf->set_input_mode("linear"); rnnconf->set_direction("unidirectional"); rnnconf->set_rnn_mode("tanh"); } singa::LayerConf conf; - size_t hidden_dim = 4; + size_t hidden_size = 4; }; TEST_F(TestCudnnRNN, Setup) { @@ -47,7 +48,7 @@ TEST_F(TestCudnnRNN, Setup) { EXPECT_EQ("CudnnRNN", rnn.layer_type()); rnn.Setup(Shape{2}, conf); auto weight = rnn.param_values().at(0); - EXPECT_EQ(weight.Size(), hidden_dim * (2 + hidden_dim + 2)); + EXPECT_EQ(weight.Size(), hidden_size * (2 + hidden_size + 2)); } TEST_F(TestCudnnRNN, Forward) { @@ -80,19 +81,19 @@ TEST_F(TestCudnnRNN, Forward) { const auto ret = rnn.Forward(singa::kEval, inputs); EXPECT_EQ(ret.size(), seqLength + 1); - vector<float> hxptr(hidden_dim, 0.0f); + vector<float> hxptr(hidden_size, 0.0f); for (size_t i = 0; i < seqLength; i++) { auto y = ret[i]; y.ToHost(); auto yptr = y.data<float>(); vector<float> tmp; - for (size_t j = 0; j < hidden_dim; j++) { + for (size_t j = 0; j < hidden_size; j++) { float ty = 0; for (size_t k = 0; k < dim; k++) { ty += x[i * dim + k] * wvalue; } ty += wvalue; - for (size_t k = 0; k < hidden_dim; k++) { + for (size_t k = 0; k < hidden_size; k++) { ty += hxptr[k] * wvalue; } ty += wvalue; @@ -134,18 +135,18 @@ TEST_F(TestCudnnRNN, Backward) { const auto outs = rnn.Forward(singa::kTrain, inputs); - float dyptr[seqLength * batchsize * hidden_dim]; - for (size_t i = 0; i < seqLength * batchsize * hidden_dim; i++) + float dyptr[seqLength * batchsize * hidden_size]; + for (size_t i = 0; i < seqLength * batchsize * hidden_size; i++) dyptr[i] = i * 0.1f; vector<Tensor> grads; for (size_t i = 0; i < 
seqLength; i++) { - Tensor dy(Shape{batchsize, hidden_dim}, cuda); + Tensor dy(Shape{batchsize, hidden_size}, cuda); dy.CopyDataFromHostPtr(dyptr + i * dy.Size(), dy.Size()); grads.push_back(dy); } Tensor dhy; grads.push_back(dhy); - vector<float> dhyptr(hidden_dim, 0.0f); + vector<float> dhyptr(hidden_size, 0.0f); const auto ret = rnn.Backward(singa::kTrain, grads); for (size_t i = seqLength - 1; i > 0 ; i --) { auto dx = ret.first[i]; @@ -154,21 +155,21 @@ TEST_F(TestCudnnRNN, Backward) { dx.ToHost(); auto dxptr = dx.data<float>(); auto yptr = y.data<float>(); - for (size_t j = 0; j < hidden_dim; j++) { - dhyptr[j] += dyptr[i * hidden_dim + j]; + for (size_t j = 0; j < hidden_size; j++) { + dhyptr[j] += dyptr[i * hidden_size + j]; dhyptr[j] *= 1 - yptr[j] * yptr[j]; } for (size_t k = 0; k < dim; k++) { float tdx = 0; - for (size_t j = 0; j < hidden_dim; j++) { + for (size_t j = 0; j < hidden_size; j++) { tdx += dhyptr[j] * wvalue; } EXPECT_NEAR(tdx, dxptr[k], 1e-4); } vector<float> tmp; - for (size_t k = 0; k < hidden_dim; k++) { + for (size_t k = 0; k < hidden_size; k++) { float tdhy = 0; - for (size_t j = 0; j < hidden_dim; j++) { + for (size_t j = 0; j < hidden_size; j++) { tdhy += dhyptr[j] * wvalue; } tmp.push_back(tdhy); @@ -176,4 +177,5 @@ TEST_F(TestCudnnRNN, Backward) { std::copy(tmp.begin(), tmp.end(), dhyptr.begin()); } } +#endif // CUDNN_VERSION_MAJOR >= 5 #endif // USE_CUDNN
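To complement the C++ test above, here is a hedged sketch of the training-time calling convention of the Python layer (forward with kTrain, backward with per-step gradients plus dummy dhy/dcy), following examples/char-rnn/train.py; all shapes, the placeholder gradient values, and the learning rate are illustrative assumptions, and the dyi would normally come from the dense/loss layers as in train.py.

    # Sketch of the <x1..xn, hx, cx> / <dy1..dyn, dhy, dcy> convention for training
    # (assumes SINGA built with cuDNN v5; shapes and values are illustrative).
    import numpy as np

    from singa import device
    from singa import initializer
    from singa import layer
    from singa import optimizer
    from singa import tensor
    from singa.proto import model_pb2

    cuda = device.create_cuda_gpu()
    vocab_size, hidden_size, num_stacks = 101, 128, 2
    seq_length, batch_size = 4, 8

    rnn = layer.LSTM(name='lstm', hidden_size=hidden_size, num_stacks=num_stacks,
                     dropout=0.5, input_sample_shape=(vocab_size,))
    rnn.to_device(cuda)
    rnn_w = rnn.param_values()[0]
    initializer.uniform(rnn_w, -0.08, 0.08)
    opt = optimizer.SGD(constraint=optimizer.L2Constraint(5))

    # inputs = <x1, ..., xn, hx, cx>; dummy Tensor() for hx/cx means zero states
    inputs = []
    for t in range(seq_length):
        x = tensor.from_numpy(
            np.random.rand(batch_size, vocab_size).astype(np.float32))
        x.to_device(cuda)
        inputs.append(x)
    inputs += [tensor.Tensor(), tensor.Tensor()]

    outputs = rnn.forward(model_pb2.kTrain, inputs)   # <y1, ..., yn, hy, cy>

    # grads = <dy1, ..., dyn, dhy, dcy>; dyi here are placeholder gradients
    grads = []
    for y in outputs[0:-2]:
        dy = tensor.Tensor(y.shape, cuda)
        dy.set_value(0.1)
        grads.append(dy)
    grads += [tensor.Tensor(), tensor.Tensor()]

    data_grads, param_grads = rnn.backward(model_pb2.kTrain, grads)
    opt.apply_with_lr(0, 0.001, param_grads[0], rnn_w, 'rnn_w')   # weight update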