SINGA-218 Implementation for RNN CUDNN version

- cuDNN RNN implementation (cudnn_rnn.h, cudnn_rnn.cc, rnn.cc, rnn.h,
test_cudnn_rnn.cc).
- The weight shape is now calculated manually instead of using the API
provided by cuDNN.
- Test for RNN_cudnn_Tanh (unidirectional, 1 hidden layer).


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/c51f9448
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/c51f9448
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/c51f9448

Branch: refs/heads/dev
Commit: c51f9448284ea905db592fb0c09d2bb0e8801828
Parents: 28678ae
Author: zhaojing <[email protected]>
Authored: Sat Jun 25 00:15:06 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Wed Aug 10 00:43:11 2016 +0800

----------------------------------------------------------------------
 src/model/layer/cudnn_rnn.cc | 328 ++++++++++++++++++++++++++++++++++++++
 src/model/layer/cudnn_rnn.h  |  85 ++++++++++
 src/model/layer/rnn.cc       |  53 ++++++
 src/model/layer/rnn.h        |  31 +++-
 src/proto/model.proto        |  17 ++
 test/singa/test_cudnn_rnn.cc | 212 ++++++++++++++++++++++++
 6 files changed, 720 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.cc b/src/model/layer/cudnn_rnn.cc
new file mode 100644
index 0000000..6f04e5c
--- /dev/null
+++ b/src/model/layer/cudnn_rnn.cc
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "./cudnn_rnn.h"
+#ifdef USE_CUDNN
+#include <cudnn.h>
+#include <chrono>
+#include "./cudnn_utils.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+CudnnRNN::~CudnnRNN() {
+  // All handles below are created lazily by InitCudnn(); a nullptr means the
+  // handle was never created and must not be destroyed.
+  if (weight_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
+  if (dropout_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
+  if (rnn_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
+  // x_descs_/y_descs_ are new[]-allocated arrays holding seqLength_
+  // descriptors each: destroy every descriptor, then free the array itself.
+  if (x_descs_ != nullptr) {
+    for (size_t i = 0; i < seqLength_; i++)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
+    delete[] x_descs_;
+    x_descs_ = nullptr;
+  }
+  if (y_descs_ != nullptr) {
+    for (size_t i = 0; i < seqLength_; i++)
+      CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
+    delete[] y_descs_;
+    y_descs_ = nullptr;
+  }
+  if (hx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
+  if (hy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
+  if (cx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
+  if (cy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
+}
+
+void CudnnRNN::Setup(const Shape& in_sample, const LayerConf &conf) {
+  RNN::Setup(in_sample, conf);
+  const RNNConf& rnn_conf = conf.rnn_conf();
+  // The limit is configured in MB. Widen to size_t before shifting so large
+  // limits (>= 2048 MB) do not overflow the 32-bit proto field.
+  CHECK_GE(rnn_conf.workspace_byte_limit(), 0);
+  workspace_byte_limit_ =
+      static_cast<size_t>(rnn_conf.workspace_byte_limit()) << 20;
+  inputMode_ = ToLowerCase(rnn_conf.inputmode());
+  direction_ = ToLowerCase(rnn_conf.direction());
+  mode_ = ToLowerCase(rnn_conf.mode());
+  CHECK(inputMode_ == "cudnn_linear_input" || inputMode_ == "cudnn_skip_input")
+      << "CudnnRNN only supports two inputmodes: cudnn_linear_input, "
+         "cudnn_skip_input";
+  CHECK(direction_ == "cudnn_undirectional" ||
+        direction_ == "cudnn_bidirectional")
+      << "CudnnRNN only supports two directions: cudnn_undirectional, "
+         "cudnn_bidirectional";
+  CHECK(mode_ == "cudnn_rnn_relu" || mode_ == "cudnn_rnn_tanh" ||
+        mode_ == "cudnn_lstm" || mode_ == "cudnn_gru")
+      << "CudnnRNN only supports four modes: cudnn_rnn_relu, "
+         "cudnn_rnn_tanh, cudnn_lstm and cudnn_gru";
+
+  // in_sample is (seqLength, batchsize, inputSize).
+  CHECK_GE(in_sample.size(), 3u);
+  const size_t inputSize = in_sample[2];
+
+  // Number of gates, i.e. sets of weights per layer and direction.
+  size_t numGates = 1;  // cudnn_rnn_relu / cudnn_rnn_tanh
+  if (mode_ == "cudnn_lstm")
+    numGates = 4;
+  else if (mode_ == "cudnn_gru")
+    numGates = 3;
+  const size_t numDirections = direction_ == "cudnn_bidirectional" ? 2 : 1;
+
+  // Per gate cuDNN keeps an input matrix W (hidden x layerInput), a recurrent
+  // matrix R (hidden x hidden) and two bias vectors of length hidden. Layers
+  // after the first consume the concatenated outputs of all directions of the
+  // previous layer. (The previous formula was only correct for
+  // inputSize == hiddenSize and a single layer.) InitCudnn() cross-checks the
+  // result against cudnnGetRNNParamsSize.
+  size_t numParams = 0;
+  for (size_t layer = 0; layer < numLayers_; layer++) {
+    const size_t layerInput =
+        layer == 0 ? inputSize : hiddenSize_ * numDirections;
+    numParams += numDirections * numGates * hiddenSize_ *
+                 (layerInput + hiddenSize_ + 2);
+  }
+  // weightSize_ is expressed in bytes of float32 parameters.
+  weightSize_ = numParams * sizeof(float);
+}
+
+void CudnnRNN::ToDevice(std::shared_ptr<Device> device) {
+  // Let RNN move the Layer state and weight_, then move the cuDNN buffers.
+  // dropoutStates_ is intentionally left in place: its device pointer was
+  // handed to the cuDNN dropout descriptor in InitCudnn().
+  RNN::ToDevice(device);
+  workspace_.ToDevice(device);
+  reserve_.ToDevice(device);
+}
+
+void CudnnRNN::InitCudnn(const Tensor &input) {
+  CHECK(!has_init_cudnn_);
+  DataType dtype = input.data_type();
+  const auto cudnnDtype = GetCudnnDataType(dtype);
+  auto dev = input.device();
+  Context *ctx = dev->context(0);
+  // input layout: (seqLength, minibatch, inputSize)
+  seqLength_ = input.shape(0);
+  CHECK_GT(seqLength_, 0u);
+  size_t batchsize = input.shape(1);
+  size_t inputSize = input.shape(2);
+  const size_t numDirections = direction_ == "cudnn_undirectional" ? 1 : 2;
+
+  x_descs_ = new cudnnTensorDescriptor_t[seqLength_];
+  y_descs_ = new cudnnTensorDescriptor_t[seqLength_];
+  for (size_t i = 0; i < seqLength_; i++) {
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
+    CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
+  }
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+  CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+  CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));
+  CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_));
+  CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_));
+
+  // Per-time-step descriptors (cuDNN's RNN API wants 3-d, fully packed):
+  // x_t is (batch, inputSize, 1), y_t is (batch, hiddenSize * dirs, 1).
+  const int batch = static_cast<int>(batchsize);
+  const int xDim[3] = {batch, static_cast<int>(inputSize), 1};
+  const int xStride[3] = {xDim[1] * xDim[2], xDim[2], 1};
+  const int yDim[3] = {batch, static_cast<int>(hiddenSize_ * numDirections),
+                       1};
+  const int yStride[3] = {yDim[1] * yDim[2], yDim[2], 1};
+  for (size_t i = 0; i < seqLength_; i++) {
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], cudnnDtype, 3, xDim,
+                                           xStride));
+    CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], cudnnDtype, 3, yDim,
+                                           yStride));
+  }
+
+  // Hidden/cell state descriptors. cuDNN requires the first dimension to be
+  // numLayers * numDirections and the last to be hiddenSize. The state
+  // tensors built in Forward() are (numLayers, batch, hiddenSize * dirs),
+  // which holds the same number of elements.
+  const int hDim[3] = {static_cast<int>(numLayers_ * numDirections), batch,
+                       static_cast<int>(hiddenSize_)};
+  const int hStride[3] = {hDim[1] * hDim[2], hDim[2], 1};
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, cudnnDtype, 3, hDim,
+                                         hStride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, cudnnDtype, 3, hDim,
+                                         hStride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, cudnnDtype, 3, hDim,
+                                         hStride));
+  CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, cudnnDtype, 3, hDim,
+                                         hStride));
+
+  // The states size is in bytes but is used as an element count here, so the
+  // buffer is over-allocated rather than too small.
+  size_t dropoutStatesSize;
+  CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle,
+                                        &dropoutStatesSize));
+  dropoutStates_ = Tensor(Shape{dropoutStatesSize}, dev, dtype);
+  CUDNN_CHECK(cudnnSetDropoutDescriptor(
+      dropout_desc_, ctx->cudnn_handle, dropout_,
+      this->dropoutStates_.block()->mutable_data(), dropoutStatesSize,
+      0x01234567));
+
+  // Setup() already validated these strings; the defaults only keep the
+  // enums initialized.
+  cudnnRNNInputMode_t inputMode = CUDNN_LINEAR_INPUT;
+  if (inputMode_ == "cudnn_skip_input") inputMode = CUDNN_SKIP_INPUT;
+  cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL;
+  if (direction_ == "cudnn_bidirectional") direction = CUDNN_BIDIRECTIONAL;
+  cudnnRNNMode_t mode = CUDNN_RNN_RELU;
+  if (mode_ == "cudnn_rnn_tanh")
+    mode = CUDNN_RNN_TANH;
+  else if (mode_ == "cudnn_lstm")
+    mode = CUDNN_LSTM;
+  else if (mode_ == "cudnn_gru")
+    mode = CUDNN_GRU;
+  CUDNN_CHECK(cudnnSetRNNDescriptor(
+      rnn_desc_, static_cast<int>(hiddenSize_), static_cast<int>(numLayers_),
+      dropout_desc_, inputMode, direction, mode, cudnnDtype));
+
+  // Cross-check the manually computed parameter size (bytes) against cuDNN.
+  size_t weightSize;
+  CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0],
+                                    &weightSize, cudnnDtype));
+  CHECK_EQ(weightSize, weightSize_);
+
+  // The filter descriptor is sized in elements, not bytes.
+  const int filterDimA[3] = {static_cast<int>(weightSize_ / sizeof(float)), 1,
+                             1};
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, cudnnDtype,
+                                         CUDNN_TENSOR_NCHW, 3, filterDimA));
+
+  // workspace_count_ and ReserveSize_ are reported in bytes. The tensors
+  // below use them as element counts (over-allocating), and Forward() and
+  // Backward() pass count * sizeof(float) as the byte size, which matches
+  // these buffers for float32 data.
+  CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_,
+                                       static_cast<int>(seqLength_), x_descs_,
+                                       &workspace_count_));
+  workspace_ = Tensor(Shape{workspace_count_}, dev, dtype);
+
+  CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_,
+                                             static_cast<int>(seqLength_),
+                                             x_descs_, &ReserveSize_));
+  reserve_ = Tensor(Shape{ReserveSize_}, dev, dtype);
+  has_init_cudnn_ = true;
+}
+
+const vector<Tensor> CudnnRNN::Forward(int flag,
+                                      const vector<Tensor>& inputs) {
+  // inputs = {x, hx, cx}; x is (seqLength, minibatch, inputSize).
+  // Returns {y, hy, cy}.
+  CHECK_GE(inputs.size(), 3u);
+  singa::Tensor input = inputs[0];
+  singa::Tensor hx = inputs[1];
+  singa::Tensor cx = inputs[2];
+  CHECK_EQ(input.device()->lang(), kCuda);
+  CHECK_EQ(input.device()->lang(), this->weight_.device()->lang());
+  CHECK_EQ(input.nDim(), 3u);
+  const bool training = (flag & kTrain) != 0;
+  size_t batchsize = input.shape(1);
+  DataType dtype = input.data_type();
+  auto dev = input.device();
+
+  if (!has_init_cudnn_) InitCudnn(input);
+
+  const size_t numDirections = direction_ == "cudnn_undirectional" ? 1 : 2;
+  Shape shape{seqLength_, batchsize, hiddenSize_ * numDirections};
+  Tensor output(shape, dev, dtype);
+  Shape shape1{numLayers_, batchsize, hiddenSize_ * numDirections};
+  Tensor hy(shape1, dev, dtype);
+  Tensor cy(shape1, dev, dtype);
+
+  output.device()->Exec(
+      [input, output, hx, hy, cx, cy, training, this](Context *ctx) {
+        Block *inblock = input.block(), *outblock = output.block(),
+              *wblock = this->weight_.block(), *hxblock = hx.block(),
+              *hyblock = hy.block(), *cxblock = cx.block(),
+              *cyblock = cy.block();
+        if (training) {
+          // Training also fills reserve_, which Backward() consumes.
+          CUDNN_CHECK(cudnnRNNForwardTraining(
+              ctx->cudnn_handle, this->rnn_desc_, this->seqLength_,
+              this->x_descs_, inblock->data(), this->hx_desc_,
+              hxblock->data(), this->cx_desc_, cxblock->data(),
+              this->weight_desc_, wblock->data(), this->y_descs_,
+              outblock->mutable_data(), this->hy_desc_,
+              hyblock->mutable_data(), this->cy_desc_,
+              cyblock->mutable_data(),
+              this->workspace_.block()->mutable_data(),
+              this->workspace_count_ * sizeof(float),
+              this->reserve_.block()->mutable_data(),
+              this->ReserveSize_ * sizeof(float)));
+        } else {
+          CUDNN_CHECK(cudnnRNNForwardInference(
+              ctx->cudnn_handle, this->rnn_desc_, this->seqLength_,
+              this->x_descs_, inblock->data(), this->hx_desc_,
+              hxblock->data(), this->cx_desc_, cxblock->data(),
+              this->weight_desc_, wblock->data(), this->y_descs_,
+              outblock->mutable_data(), this->hy_desc_,
+              hyblock->mutable_data(), this->cy_desc_,
+              cyblock->mutable_data(),
+              this->workspace_.block()->mutable_data(),
+              this->workspace_count_ * sizeof(float)));
+        }
+      },
+      {input.block(), weight_.block(), hx.block(), cx.block()},
+      {output.block(), hy.block(), cy.block(), workspace_.block(),
+       reserve_.block()},
+      workspace_.block());
+
+  // Buffer what Backward() pops (it pops in reverse order). Only do this for
+  // training; otherwise buf_ would grow on every inference call and the
+  // stack would no longer line up with Backward().
+  if (training) {
+    buf_.push(input);
+    buf_.push(output);
+    buf_.push(hx);
+    buf_.push(hy);  // in order to assign shape to dhy
+    buf_.push(cx);
+    buf_.push(cy);  // in order to assign shape to dcy
+  }
+  return vector<Tensor>{output, hy, cy};
+}
+
+const std::pair<vector<Tensor>, vector<Tensor>> CudnnRNN::Backward(
+    int flag, const vector<Tensor>& grads) {
+  // grads = {dy, dhy, dcy}. Returns {{dx, dhx, dcx}, {dw}}.
+  CHECK(has_init_cudnn_);
+  CHECK_GE(grads.size(), 3u);
+  singa::Tensor grad = grads[0];
+  singa::Tensor dhy = grads[1];
+  singa::Tensor dcy = grads[2];
+  CHECK_EQ(grad.device()->lang(), kCuda);
+  CHECK_EQ(grad.nDim(), 3u);
+  // Pop in the reverse of the order Forward() pushed.
+  CHECK(!buf_.empty());
+  Tensor cy = buf_.top();
+  buf_.pop();
+  CHECK(!buf_.empty());
+  Tensor cx = buf_.top();
+  buf_.pop();
+  CHECK(!buf_.empty());
+  Tensor hy = buf_.top();
+  buf_.pop();
+  CHECK(!buf_.empty());
+  Tensor hx = buf_.top();
+  buf_.pop();
+  CHECK(!buf_.empty());
+  Tensor src_output = buf_.top();
+  buf_.pop();
+  CHECK(!buf_.empty());
+  Tensor src_data = buf_.top();
+  buf_.pop();
+
+  vector<Tensor> param_grad;
+  vector<Tensor> data_grad;
+  Tensor dx;
+  dx.ResetLike(src_data);
+  Tensor dw;
+  dw.ResetLike(weight_);
+  // cudnnRNNBackwardWeights accumulates into dw, so it must start at zero.
+  dw.SetValue(0.0f);
+  Tensor dhx;
+  dhx.ResetLike(hx);
+  Tensor dcx;
+  dcx.ResetLike(cx);
+
+  // Data gradients first: cudnnRNNBackwardData writes into reserve_ the
+  // intermediate results that cudnnRNNBackwardWeights requires.
+  dx.device()->Exec(
+      [grad, src_output, dx, cx, hx, dhy, dcy, dhx, dcx, this](Context *ctx) {
+        Block *srcoutblock = src_output.block(),
+              *wblock = this->weight_.block(), *dxblock = dx.block(),
+              *dyblock = grad.block(), *cxblock = cx.block(),
+              *hxblock = hx.block(), *dhyblock = dhy.block(),
+              *dcyblock = dcy.block(), *dhxblock = dhx.block(),
+              *dcxblock = dcx.block();
+        CUDNN_CHECK(cudnnRNNBackwardData(
+            ctx->cudnn_handle, this->rnn_desc_, this->seqLength_,
+            this->y_descs_, srcoutblock->data(), this->y_descs_,
+            dyblock->data(), this->hy_desc_, dhyblock->data(),
+            this->cy_desc_, dcyblock->data(), this->weight_desc_,
+            wblock->data(), this->hx_desc_, hxblock->data(), this->cx_desc_,
+            cxblock->data(), this->x_descs_, dxblock->mutable_data(),
+            this->hx_desc_, dhxblock->mutable_data(), this->cx_desc_,
+            dcxblock->mutable_data(),
+            this->workspace_.block()->mutable_data(),
+            this->workspace_count_ * sizeof(float),
+            this->reserve_.block()->mutable_data(),
+            this->ReserveSize_ * sizeof(float)));
+      },
+      {src_output.block(), grad.block(), dhy.block(), dcy.block(),
+       weight_.block(), hx.block(), cx.block()},
+      {dx.block(), dhx.block(), dcx.block(), reserve_.block(),
+       workspace_.block()});
+
+  // Weight gradients, consuming the reserve space produced above.
+  dx.device()->Exec(
+      [dw, src_data, src_output, hx, this](Context *ctx) {
+        Block *inblock = src_data.block(), *srcoutblock = src_output.block(),
+              *dwblock = dw.block(), *hxblock = hx.block();
+        CUDNN_CHECK(cudnnRNNBackwardWeights(
+            ctx->cudnn_handle, this->rnn_desc_, this->seqLength_,
+            this->x_descs_, inblock->data(), this->hx_desc_, hxblock->data(),
+            this->y_descs_, srcoutblock->data(),
+            this->workspace_.block()->mutable_data(),
+            this->workspace_count_ * sizeof(float), this->weight_desc_,
+            dwblock->mutable_data(), this->reserve_.block()->mutable_data(),
+            this->ReserveSize_ * sizeof(float)));
+      },
+      {src_data.block(), hx.block(), src_output.block(), reserve_.block()},
+      {dw.block(), workspace_.block()});
+
+  param_grad.push_back(dw);
+  data_grad.push_back(dx);
+  data_grad.push_back(dhx);
+  data_grad.push_back(dcx);
+  return std::make_pair(data_grad, param_grad);
+}
+
+}  // namespace singa
+#endif  // USE_CUDNN

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/cudnn_rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/cudnn_rnn.h b/src/model/layer/cudnn_rnn.h
new file mode 100644
index 0000000..b1e9f43
--- /dev/null
+++ b/src/model/layer/cudnn_rnn.h
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SRC_MODEL_LAYER_CUDNN_RNN_H_
+#define SRC_MODEL_LAYER_CUDNN_RNN_H_
+#include "singa/singa_config.h"
+#ifdef USE_CUDNN
+#include <string>
+#include <utility>
+#include <vector>
+#include "./rnn.h"
+#include "singa/core/common.h"
+#include "singa/model/layer.h"
+#include "singa/proto/core.pb.h"
+#include "singa/utils/string.h"
+#include <cudnn.h>
+#include <chrono>
+#include "./cudnn_utils.h"
+#include "singa/utils/logging.h"
+
+namespace singa {
+/// Recurrent layer backed by cuDNN (RNN-ReLU, RNN-tanh, LSTM or GRU cells).
+///
+/// Forward() takes {x, hx, cx} with x laid out as
+/// (seqLength, batchsize, inputSize) and returns {y, hy, cy}. Backward()
+/// takes {dy, dhy, dcy} and returns {{dx, dhx, dcx}, {dw}}. The cuDNN
+/// descriptors and buffers are created by InitCudnn() on the first Forward().
+class CudnnRNN : public RNN {
+ public:
+  ~CudnnRNN();
+  /// \copydoc Layer::layer_type()
+  const std::string layer_type() const override { return "CudnnRNN"; }
+
+  const vector<Tensor> Forward(int flag,
+                               const vector<Tensor>& inputs) override;
+  const std::pair<vector<Tensor>, vector<Tensor>> Backward(
+      int flag, const vector<Tensor>& grads) override;
+
+  /// \copydoc Layer::Setup(const LayerConf&);
+  void Setup(const Shape& in_sample, const LayerConf &conf) override;
+
+  void ToDevice(std::shared_ptr<Device> device) override;
+
+  /// Workspace limit in bytes (configured in MB through RNNConf).
+  size_t workspace_byte_limit() const { return workspace_byte_limit_; }
+  string inputMode() const { return inputMode_; }
+  string direction() const { return direction_; }
+  string mode() const { return mode_; }
+
+ protected:
+  /// Init cudnn related data structures.
+  void InitCudnn(const Tensor& input);
+
+ protected:
+  bool has_init_cudnn_ = false;
+  /// Per-time-step input/output descriptors; new[]-allocated arrays of
+  /// seqLength_ entries, created in InitCudnn().
+  cudnnTensorDescriptor_t* x_descs_ = nullptr;
+  cudnnTensorDescriptor_t* y_descs_ = nullptr;
+  cudnnTensorDescriptor_t hx_desc_ = nullptr;
+  cudnnTensorDescriptor_t cx_desc_ = nullptr;
+  cudnnTensorDescriptor_t hy_desc_ = nullptr;
+  cudnnTensorDescriptor_t cy_desc_ = nullptr;
+  cudnnFilterDescriptor_t weight_desc_ = nullptr;
+  cudnnRNNDescriptor_t rnn_desc_ = nullptr;
+  cudnnDropoutDescriptor_t dropout_desc_ = nullptr;
+  size_t workspace_byte_limit_ = 0;
+  /// Value reported by cudnnGetRNNWorkspaceSize.
+  size_t workspace_count_ = 0;
+  /// Value reported by cudnnGetRNNTrainingReserveSize.
+  size_t ReserveSize_ = 0;
+  Tensor workspace_;
+  string inputMode_;
+  string direction_;
+  string mode_;
+  Tensor reserve_;
+  Tensor dropoutStates_;
+};
+
+}  // namespace singa
+
+#endif  // USE_CUDNN
+#endif  // SRC_MODEL_LAYER_CUDNN_RNN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/rnn.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.cc b/src/model/layer/rnn.cc
new file mode 100644
index 0000000..493a5e4
--- /dev/null
+++ b/src/model/layer/rnn.cc
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "./rnn.h"
+#include <vector>
+#include "singa/model/layer.h"
+
+namespace singa {
+
+void RNN::Setup(const Shape& in_sample, const LayerConf &conf) {
+  Layer::Setup(in_sample, conf);
+  const RNNConf& rnn_conf = conf.rnn_conf();
+  // Number of units in each hidden state.
+  hiddenSize_ = rnn_conf.hiddensize();
+  CHECK_GT(hiddenSize_, 0u);
+
+  // Number of stacked recurrent layers.
+  numLayers_ = rnn_conf.numlayers();
+  CHECK_GT(numLayers_, 0u);
+
+  // Dropout probability; it must lie in [0, 1].
+  dropout_ = rnn_conf.dropout();
+  CHECK_GE(dropout_, 0.0f);
+  CHECK_LE(dropout_, 1.0f);
+}
+
+const vector<Tensor> RNN::Forward(int flag, const vector<Tensor>& inputs) {
+  // The generic RNN layer performs no computation; subclasses such as
+  // CudnnRNN supply the real forward pass. No outputs are produced here.
+  return vector<Tensor>{};
+}
+
+const std::pair<vector<Tensor>, vector<Tensor>> RNN::Backward(
+    int flag, const vector<Tensor>& grads) {
+  // Base implementation: yields neither data gradients (first) nor
+  // parameter gradients (second); subclasses override this.
+  return std::make_pair(vector<Tensor>(), vector<Tensor>());
+}
+
+void RNN::ToDevice(std::shared_ptr<Device> device) {
+  // Move the generic Layer state first, then the packed RNN weight tensor.
+  Layer::ToDevice(device);
+  this->weight_.ToDevice(device);
+}
+}  /* singa */

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/model/layer/rnn.h
----------------------------------------------------------------------
diff --git a/src/model/layer/rnn.h b/src/model/layer/rnn.h
index 35c86bd..ec5a35d 100644
--- a/src/model/layer/rnn.h
+++ b/src/model/layer/rnn.h
@@ -38,21 +38,32 @@ class RNN : public Layer {
   const std::string layer_type() const override { return "RNN"; }
 
   /// \copydoc Layer::Setup(const LayerConf&);
-  void Setup(const LayerConf& conf) override;
+  void Setup(const vector<size_t>& in_shape, const LayerConf& conf) override;
 
   /// \copydoc Layer::Forward(int flag, const vector<Tensor>&)
-  const vector<Tensor> Forward(int flag, const vector<Tensor>& input) override;
+  const vector<Tensor> Forward(int flag, const vector<Tensor>& inputs) 
override;
 
   /// \copydoc Layer::Backward(int, const vector<Tensor>&);
   const std::pair<vector<Tensor>, vector<Tensor>> Backward(
-      int flag, const vector<Tensor>& grad) override;
+      int flag, const vector<Tensor>& grads) override;
 
-  void ToDevice(Device* device) override;
 
+  size_t hiddenSize() const { return hiddenSize_; }
+  size_t numLayers() const { return numLayers_; }
+  size_t weightSize() const { return weightSize_; }
+  float dropout() const { return dropout_; }
+  
+  void set_weight(Tensor w) {
+    weight_.ResetLike(w);
+    weight_.CopyData(w);
+  }
+
+
+  void ToDevice(std::shared_ptr<Device> device) override;
   /// Return the internal state stack, which should be empty at the beginning
   /// of
   /// one iteration.
-  std::stack<Tensor> states() const { return states_; }
+  // std::stack<Tensor> states() const { return states_; }
 
  protected:
   /// Storing input or output from Forward(), which are used in Backward().
@@ -60,7 +71,15 @@ class RNN : public Layer {
   /// 1. push the 'input' or 'output' into states_ if the flag of Forward() is
   ///    for kTrain and 'input' or 'output' is necessary for Backward().
   /// 2. pop data out in Backward().
-  std::stack<Tensor*> states_;
+  // std::stack<Tensor*> states_;
+  std::stack<Tensor> buf_;
+  size_t hiddenSize_;
+  size_t numLayers_;
+  size_t numLinearLayer_;
+  size_t seqLength_;
+  size_t weightSize_; /*all the weights and biases*/
+  float dropout_;
+  Tensor weight_;
 };
 }  // namespace singa
 #endif  // SRC_MODEL_LAYER_RNN_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/src/proto/model.proto
----------------------------------------------------------------------
diff --git a/src/proto/model.proto b/src/proto/model.proto
index b1318d9..d8193f1 100644
--- a/src/proto/model.proto
+++ b/src/proto/model.proto
@@ -203,6 +203,7 @@ message LayerConf {
   optional ConcatConf concat_conf = 104;
   optional ContrastiveLossConf contrastive_loss_conf = 105;
   optional ConvolutionConf convolution_conf = 106;
+  optional RNNConf rnn_conf = 140;
   // optional DataConf data_conf = 107;
   optional DropoutConf dropout_conf = 108;
   // optional DummyDataConf dummy_data_conf = 109;
@@ -391,6 +392,22 @@ message ConvolutionConf {
   optional string prefer = 51 [default = "fastest"];
 }
 
+message RNNConf {
+  optional uint32 hiddensize = 1; // Number of units in each hidden state
+  optional uint32 numlayers = 2; // The number of stacked RNN layers
+  // Dropout probability handed to the cuDNN dropout descriptor; must be
+  // non-negative (checked in RNN::Setup)
+  optional float dropout = 3 [default = 0];
+  // Workspace limit in MB (CudnnRNN::Setup converts it to bytes)
+  optional int32 workspace_byte_limit = 50 [default = 512];
+  // cudnn inputmode
+  // options: "cudnn_linear_input", "cudnn_skip_input"
+  optional string inputmode = 51 [default = "cudnn_linear_input"];
+  // cudnn direction
+  // options: "cudnn_undirectional" (unidirectional; the option string is
+  // spelled this way in code), "cudnn_bidirectional"
+  optional string direction = 52 [default = "cudnn_undirectional"];
+  // cudnn RNN mode
+  // options: "cudnn_rnn_relu", "cudnn_rnn_tanh", "cudnn_lstm", "cudnn_gru"
+  optional string mode = 53 [default = "cudnn_rnn_relu"];
+}
+
 /*
 message DataConf {
   enum DB {

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/c51f9448/test/singa/test_cudnn_rnn.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cudnn_rnn.cc b/test/singa/test_cudnn_rnn.cc
new file mode 100644
index 0000000..1a79d7b
--- /dev/null
+++ b/test/singa/test_cudnn_rnn.cc
@@ -0,0 +1,212 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "../src/model/layer/cudnn_rnn.h"
+#ifdef USE_CUDNN
+
+#include "gtest/gtest.h"
+
+using singa::CudnnRNN;
+using singa::Shape;
+TEST(CudnnRNN, Setup) {
+  CudnnRNN layer;
+  EXPECT_EQ("CudnnRNN", layer.layer_type());
+
+  // Single-layer, unidirectional tanh RNN with two hidden units.
+  singa::LayerConf conf;
+  singa::RNNConf *rnn_conf = conf.mutable_rnn_conf();
+  rnn_conf->set_hiddensize(2);
+  rnn_conf->set_numlayers(1);
+  rnn_conf->set_dropout(0);
+  rnn_conf->set_inputmode("cudnn_linear_input");
+  rnn_conf->set_direction("cudnn_undirectional");
+  rnn_conf->set_mode("cudnn_rnn_tanh");
+  rnn_conf->set_workspace_byte_limit(256);  // in MB
+  // sample shape: (seqLength, batchsize, inputSize)
+  layer.Setup(Shape{4, 1, 2}, conf);
+
+  EXPECT_EQ(2u, layer.hiddenSize());
+  EXPECT_EQ(1u, layer.numLayers());
+  EXPECT_EQ(0u, layer.dropout());
+  EXPECT_EQ("cudnn_linear_input", layer.inputMode());
+  EXPECT_EQ("cudnn_undirectional", layer.direction());
+  EXPECT_EQ("cudnn_rnn_tanh", layer.mode());
+  // 256 MB expressed in bytes
+  EXPECT_EQ(256u << 20, layer.workspace_byte_limit());
+}
+
+TEST(CudnnRNN, Forward) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  const size_t seqLength = 4, batchsize = 1, dim = 2;
+  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
+  const size_t stateSize = numLayers * batchsize * hiddensize * numDirections;
+  const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f,
+                                                1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
+  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
+
+  const float hx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor hx(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  hx.CopyDataFromHostPtr(hx_data, stateSize);
+
+  const float cx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor cx(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  cx.CopyDataFromHostPtr(cx_data, stateSize);
+
+  CudnnRNN rnn;
+
+  singa::LayerConf conf;
+  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+  rnnconf->set_hiddensize(2);
+  rnnconf->set_numlayers(1);
+  rnnconf->set_dropout(0);
+  rnnconf->set_inputmode("cudnn_linear_input");
+  rnnconf->set_direction("cudnn_undirectional");
+  rnnconf->set_mode("cudnn_rnn_tanh");
+  // MB
+  rnnconf->set_workspace_byte_limit(256);
+  rnn.Setup(Shape{4, 1, 2}, conf);
+
+  // Set every weight to 1. weightSize() is a byte count, so this allocates
+  // more floats than there are parameters. A std::vector replaces the former
+  // variable-length array, which is not standard C++.
+  size_t weightSize = rnn.weightSize();
+  std::vector<float> we(weightSize, 1.0f);
+  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+  weight.CopyDataFromHostPtr(we.data(), weightSize);
+  rnn.set_weight(weight);
+
+  vector<singa::Tensor> input_array;
+  input_array.push_back(in);
+  input_array.push_back(hx);
+  input_array.push_back(cx);
+  const auto ret = rnn.Forward(singa::kTrain, input_array);
+
+  // With all inputs, states and weights equal to 1, every step computes
+  // tanh(W x + R h + bW + bR) ~= tanh(6) ~= 1.
+  singa::Tensor out1 = ret[0];
+  out1.ToHost();
+  const float *outptr1 = out1.data<float>();
+  EXPECT_EQ(8u, out1.Size());
+  for (size_t i = 0; i < out1.Size(); i++)
+    EXPECT_NEAR(1.0f, outptr1[i], 0.0001);
+
+  singa::Tensor hy1 = ret[1];
+  hy1.ToHost();
+  const float *hyptr1 = hy1.data<float>();
+  EXPECT_EQ(2u, hy1.Size());
+  EXPECT_NEAR(1.0f, hyptr1[0], 0.0001);
+  EXPECT_NEAR(1.0f, hyptr1[1], 0.0001);
+}
+
+TEST(CudnnRNN, Backward) {
+  // src_data
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  const size_t seqLength = 4, batchsize = 1, dim = 2;
+  const size_t numLayers = 1, hiddensize = 2, numDirections = 1;
+  const size_t stateSize = numLayers * batchsize * hiddensize * numDirections;
+  const size_t outSize = seqLength * batchsize * hiddensize * numDirections;
+  const float x[seqLength * batchsize * dim] = {1.0f, 1.0f, 1.0f, 1.0f,
+                                                1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor in(singa::Shape{seqLength, batchsize, dim}, cuda);
+  in.CopyDataFromHostPtr(x, seqLength * batchsize * dim);
+
+  const float hx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor hx(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  hx.CopyDataFromHostPtr(hx_data, stateSize);
+
+  const float cx_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor cx(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  cx.CopyDataFromHostPtr(cx_data, stateSize);
+
+  CudnnRNN rnn;
+
+  singa::LayerConf conf;
+  singa::RNNConf *rnnconf = conf.mutable_rnn_conf();
+  rnnconf->set_hiddensize(2);
+  rnnconf->set_numlayers(1);
+  rnnconf->set_dropout(0);
+  rnnconf->set_inputmode("cudnn_linear_input");
+  rnnconf->set_direction("cudnn_undirectional");
+  rnnconf->set_mode("cudnn_rnn_tanh");
+  // MB
+  rnnconf->set_workspace_byte_limit(256);
+  rnn.Setup(Shape{4, 1, 2}, conf);
+
+  // Set every weight to 1. weightSize() is a byte count, so this allocates
+  // more floats than there are parameters. A std::vector replaces the former
+  // variable-length array, which is not standard C++.
+  size_t weightSize = rnn.weightSize();
+  std::vector<float> we(weightSize, 1.0f);
+  singa::Tensor weight(singa::Shape{weightSize, 1, 1}, cuda);
+  weight.CopyDataFromHostPtr(we.data(), weightSize);
+  rnn.set_weight(weight);
+
+  vector<singa::Tensor> input_array;
+  input_array.push_back(in);
+  input_array.push_back(hx);
+  input_array.push_back(cx);
+  const auto ret = rnn.Forward(singa::kTrain, input_array);
+
+  // grad
+  const float dy[outSize] = {1.0f, 1.0f, 1.0f, 1.0f,
+                             1.0f, 1.0f, 1.0f, 1.0f};
+  singa::Tensor grad(
+      singa::Shape{seqLength, batchsize, hiddensize * numDirections}, cuda);
+  grad.CopyDataFromHostPtr(dy, outSize);
+
+  const float dhy_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor dhy(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  dhy.CopyDataFromHostPtr(dhy_data, stateSize);
+
+  const float dcy_data[stateSize] = {1.0f, 1.0f};
+  singa::Tensor dcy(
+      singa::Shape{numLayers, batchsize, hiddensize * numDirections}, cuda);
+  dcy.CopyDataFromHostPtr(dcy_data, stateSize);
+
+  vector<singa::Tensor> grad_array;
+  grad_array.push_back(grad);
+  grad_array.push_back(dhy);
+  grad_array.push_back(dcy);
+  const auto ret_back = rnn.Backward(singa::kTrain, grad_array);
+
+  singa::Tensor in_grad = ret_back.first[0];
+  in_grad.ToHost();
+  const float *dx = in_grad.data<float>();
+  EXPECT_EQ(8u, in_grad.Size());
+  EXPECT_NEAR(0.14, dx[0], 0.0001);
+  EXPECT_NEAR(0.14, dx[1], 0.0001);
+  EXPECT_NEAR(0.1596, dx[2], 0.0001);
+  EXPECT_NEAR(0.1596, dx[3], 0.0001);
+  EXPECT_NEAR(0.1623, dx[4], 0.0001);
+  EXPECT_NEAR(0.1623, dx[5], 0.0001);
+  EXPECT_NEAR(0.1627, dx[6], 0.0001);
+  EXPECT_NEAR(0.1627, dx[7], 0.0001);
+
+  singa::Tensor dhx_grad = ret_back.first[1];
+  dhx_grad.ToHost();
+  const float *dhx = dhx_grad.data<float>();
+  EXPECT_EQ(2u, dhx_grad.Size());
+  EXPECT_NEAR(0.1627, dhx[0], 0.0001);
+  EXPECT_NEAR(0.1627, dhx[1], 0.0001);
+}
+#endif  // USE_CUDNN

Reply via email to