Repository: incubator-singa
Updated Branches:
  refs/heads/master c628e900e -> 9eea5b53f
SINGA-324 Extend RNN layer to accept variant seq length across batches

The cuDNN RNN layer is updated to handle mini-batches with different sequence
lengths. The internal data structures are re-allocated when max_length_ <
seq_length_, where seq_length_ is the length of the longest sample in the
current mini-batch; i.e. buffers grow only when a mini-batch is longer than
any seen so far.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2ce7229a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2ce7229a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2ce7229a

Branch: refs/heads/master
Commit: 2ce7229ad688f54b7fb88fe46c23e771dfb87365
Parents: 12db1be
Author: Wang Wei <[email protected]>
Authored: Sat Jun 17 15:27:44 2017 +0800
Committer: Wang Wei <[email protected]>
Committed: Sat Jun 17 15:27:44 2017 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_rnn.cc | 7 ++++---
 src/model/layer/rnn.h        | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2ce7229a/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
index 583dcda..62c6355 100644
--- a/src/model/layer/cudnn_rnn.cc
+++ b/src/model/layer/cudnn_rnn.cc
@@ -60,7 +60,7 @@ void CudnnRNN::ToDevice(std::shared_ptr<Device> device) {

 void CudnnRNN::DestroyIODescriptors() {
   if (x_descs_ != nullptr) {
-    for (size_t i = 0; i < seq_length_; i++) {
+    for (size_t i = 0; i < max_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
     }
@@ -68,7 +68,7 @@ void CudnnRNN::DestroyIODescriptors() {
     delete [] dx_descs_;
   }
   if (y_descs_ != nullptr) {
-    for (size_t i = 0; i < seq_length_; i++) {
+    for (size_t i = 0; i < max_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
     }
@@ -79,8 +79,9 @@ void CudnnRNN::DestroyIODescriptors() {

 void CudnnRNN::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
   bool reset = false;
-  if (seq_length_ < len) {
+  if (max_length_ < len) {
     DestroyIODescriptors();
+    max_length_ = len;
     x_descs_ = new cudnnTensorDescriptor_t[len];
     dx_descs_ = new cudnnTensorDescriptor_t[len];
     y_descs_ = new cudnnTensorDescriptor_t[len];

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2ce7229a/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index 3369a00..4b442d0 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -85,7 +85,7 @@ class RNN : public Layer {
   std::stack<Tensor> buf_;
   bool has_cell_ = false;
   size_t num_directions_ = 1;
-  size_t input_size_ = 0, hidden_size_ = 0, num_stacks_ = 0, seq_length_ = 0;
+  size_t input_size_ = 0, hidden_size_ = 0, num_stacks_ = 0, seq_length_ = 0, max_length_ = 0;
   size_t batch_size_ = 0;
   size_t seed_ = 0x1234567;
   float dropout_ = 0.0f;
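The change amounts to a grow-only allocation policy: the per-step descriptor
arrays are rebuilt only when a mini-batch arrives whose sequence length exceeds
max_length_, and the destroy loops must iterate over max_length_ (the allocated
size) rather than seq_length_ (the current batch's length), because seq_length_
may have shrunk since the arrays were allocated. Below is a minimal standalone
C++ sketch of that pattern, not SINGA's actual code: Descriptor stands in for
cudnnTensorDescriptor_t, and the class and method names are illustrative only.

#include <cstddef>
#include <iostream>

struct Descriptor { /* placeholder for a per-step cuDNN tensor descriptor */ };

class RnnBuffers {
 public:
  ~RnnBuffers() { Destroy(); }

  // Called once per mini-batch with that batch's sequence length.
  void Update(size_t len) {
    if (max_length_ < len) {   // grow only; shorter batches reuse the buffers
      Destroy();               // frees the max_length_ old descriptors
      max_length_ = len;       // remember the largest length seen so far
      x_descs_ = new Descriptor[len];
      y_descs_ = new Descriptor[len];
      std::cout << "re-allocated for length " << len << "\n";
    }
    seq_length_ = len;         // the current batch uses only the first len steps
  }

 private:
  void Destroy() {
    // The real layer destroys each cudnnTensorDescriptor_t in a loop bounded
    // by max_length_ before deleting the arrays; here delete[] suffices.
    delete[] x_descs_;  x_descs_ = nullptr;
    delete[] y_descs_;  y_descs_ = nullptr;
  }

  Descriptor* x_descs_ = nullptr;
  Descriptor* y_descs_ = nullptr;
  size_t seq_length_ = 0;  // sequence length of the current mini-batch
  size_t max_length_ = 0;  // largest sequence length allocated so far
};

int main() {
  RnnBuffers buf;
  buf.Update(10);  // allocates descriptors for 10 steps
  buf.Update(6);   // reuses the existing buffers
  buf.Update(12);  // longer than any batch so far: re-allocates
  return 0;
}

Keeping buffers sized to the largest batch avoids re-allocating cuDNN
descriptors on every mini-batch when sequence lengths fluctuate.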
