This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1c49e40  Change RNN OP to stateful (#14476)
1c49e40 is described below

commit 1c49e40fdb20fc52a244e45605fbb1d04e283918
Author: Hao Li <[email protected]>
AuthorDate: Sat Apr 13 11:24:00 2019 +0800

    Change RNN OP to stateful (#14476)
    
    * change RNN OP to stateful
    
    * retrigger the ci
    
    * fix windows compile issue
    
    * move cudnnrnn class into rnn-inl.h
    
    * fix some bugs
    
    * fix bug about gpu case like tests/python/gpu/test_gluon_gpu.test_rnn_layers_fp32 etc
    
    * fix gpu compile issue of unix-gpu and windows-gpu
    
    * print log for test
    
    * fix GPU NO CUDNN for unix-gpu case
    
    * print log
    
    * remove print log and make gpu case has same error message as master when USE_CUDA=1 USE_CUDNN=0
    
    * fix typo bug
    
    * retrigger the ci
    
    * retrigger the ci
    
    * retrigger the ci
    
    * fix comments
    
    * retrigger the ci
    
    * fix comments
    
    * retrigger the ci
    
    * fix comments
    
    * retrigger the ci
    
    * retrigger the ci
    
    * retrigger the ci
    
    * retrigger the ci
    
    * fix comments
---
 src/operator/cudnn_rnn-inl.h |  863 ---------------------------
 src/operator/rnn-inl.h       | 1346 +++++++++++++++++++++++++++++++-----------
 src/operator/rnn.cc          |  198 ++++++-
 src/operator/rnn.cu          |   20 +-
 4 files changed, 1182 insertions(+), 1245 deletions(-)

diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
deleted file mode 100644
index cc8e4db..0000000
--- a/src/operator/cudnn_rnn-inl.h
+++ /dev/null
@@ -1,863 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2016 by Contributors
- * \file cudnn_rnn-inl.h
- * \brief
- * \author Sebastian Bodenstein
-*/
-#ifndef MXNET_OPERATOR_CUDNN_RNN_INL_H_
-#define MXNET_OPERATOR_CUDNN_RNN_INL_H_
-
-#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
-
-#include <mxnet/storage.h>
-#include <vector>
-#include <map>
-#include <string>
-#include <utility>
-#include <cstdint>
-#include "./rnn-inl.h"
-
-namespace mxnet {
-namespace op {
-#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-template<typename DType>
-class CuDNNRNNOp : public Operator {
- public:
-  explicit CuDNNRNNOp(RNNParam param) {
-    this->param_ = param;
-    init_cudnn_ = false;
-    dtype_ = mshadow::DataType<DType>::kCudnnFlag;
-    // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
-    // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
-    cudnn_tensor_core_ = false;
-    // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
-//    cudnn_tensor_core =
-//        mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
-    // Defaults
-    input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
-    // RNN Mode
-    switch (param_.mode) {
-      case rnn_enum::kRnnRelu:
-        mode_ = CUDNN_RNN_RELU;
-        break;
-      case rnn_enum::kRnnTanh:
-        mode_ = CUDNN_RNN_TANH;
-        break;
-      case rnn_enum::kLstm:
-        mode_ = CUDNN_LSTM;
-        break;
-      case rnn_enum::kGru:
-        mode_ = CUDNN_GRU;
-        break;
-      default:
-        LOG(FATAL) << "Not implmented";
-    }
-#if USE_CUDNN_LSTM_PROJ
-    if (param_.projection_size.has_value()) {
-      CHECK_EQ(param_.mode, rnn_enum::kLstm)
-        << "Projection is only supported for LSTM.";
-      CHECK_GE(param_.state_size, param_.projection_size.value())
-        << "State size must be larger than projection size.";
-    }
-#else
-    CHECK(!param_.projection_size.has_value())
-      << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
-#endif
-#if USE_CUDNN_LSTM_PROJ
-    if (param_.lstm_state_clip_min.has_value()
-        || param_.lstm_state_clip_max.has_value()) {
-      CHECK_EQ(param_.mode, rnn_enum::kLstm)
-        << "State clipping is only supported for LSTM.";
-      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
-        << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
-      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
-        << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
-    }
-#else
-    CHECK(!param_.lstm_state_clip_min.has_value()
-          && !param_.lstm_state_clip_max.has_value())
-      << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
-#endif
-    // RNN Direction
-    direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
-    // Other
-    if (param_.mode == rnn_enum::kLstm)
-      param_.lstm_q_ = true;
-    else
-      param_.lstm_q_ = false;
-
-    // Create descriptors
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
-    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
-
-    CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
-    CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
-
-    CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
-    CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
-
-    #if USE_CUDNN_LSTM_PROJ
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
-    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
-    #endif
-  }
-
-  ~CuDNNRNNOp() {
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
-    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));
-
-    CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
-    CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
-    CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
-    CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
-
-    if (init_cudnn_) {
-      for (size_t i = 0; i < x_desc_vec_.size(); ++i) {
-        CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i]));
-        CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i]));
-        CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]));
-        CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
-      }
-      init_cudnn_ = false;
-
-      Storage::Get()->Free(reserve_space_);
-      if (param_.p > 0) {
-        Storage::Get()->Free(dropout_states_);
-      }
-    }
-    #if USE_CUDNN_LSTM_PROJ
-    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
-    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
-    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
-    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
-    #endif
-  }
-
-  virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    size_t in_expected = param_.lstm_q_ ? 4 : 3;
-    size_t out_expected = param_.lstm_q_ ? 3 : 2;
-    if (!param_.state_outputs)
-        out_expected = 1;
-
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    Stream<gpu> *s = ctx.get_stream<gpu>();
-    // get input + output tensors
-    Tensor<gpu, 3, DType> x = in_data[rnn_enum::kData].get<gpu, 3, DType>(s);
-    Tensor<gpu, 1, DType> w = in_data[rnn_enum::kParams].get<gpu, 1, DType>(s);
-    Tensor<gpu, 3, DType> hx = in_data[rnn_enum::kState].get<gpu, 3, DType>(s);
-    Tensor<gpu, 3, DType> y = out_data[rnn_enum::kOut].get<gpu, 3, DType>(s);
-
-    void * hy_ptr = NULL;
-    if (param_.state_outputs)
-      hy_ptr = out_data[rnn_enum::kStateOut].get<gpu, 3, DType>(s).dptr_;
-
-    DType * cx_ptr = NULL;
-    DType * cy_ptr = NULL;
-
-    if (param_.lstm_q_)
-      cx_ptr = (in_data[rnn_enum::kStateCell].get<gpu, 3, DType>(s)).dptr_;
-    if (param_.lstm_q_ && param_.state_outputs)
-      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<gpu, 3, DType>(s)).dptr_;
-
-    CHECK_EQ(x.CheckContiguous(), true);
-    CHECK_EQ(w.CheckContiguous(), true);
-    CHECK_EQ(hx.CheckContiguous(), true);
-    CHECK_EQ(y.CheckContiguous(), true);
-
-    if (!init_cudnn_) {
-      Init(s, in_data, out_data);
-    }
-    // Get temp space
-    int temp_size = workspace_size_;
-    Tensor<gpu, 1, DType> temp_space =
-      ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
-                              mshadow::Shape1(temp_size), s);
-    #if USE_CUDNN_LSTM_PROJ
-    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
-    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
-                                         dtype_,
-                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
-                                         param_.seq_length_,
-                                         param_.batch_size_,
-                                         param_.input_size_,
-                                         seqLengthArray.data(),
-                                         nullptr));
-    int out_size =
-      (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
-    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
-    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
-                                         dtype_,
-                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
-                                         param_.seq_length_,
-                                         param_.batch_size_,
-                                         out_size,
-                                         seqLengthArray.data(),
-                                         nullptr));
-    if (ctx.is_train) {
-      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
-                                           dtype_,
-                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
-                                           param_.seq_length_,
-                                           param_.batch_size_,
-                                           param_.input_size_,
-                                           seqLengthArray.data(),
-                                           nullptr));
-      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
-                                           dtype_,
-                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
-                                           param_.seq_length_,
-                                           param_.batch_size_,
-                                           out_size,
-                                           seqLengthArray.data(),
-                                           nullptr));
-    }
-    #endif
-
-    #if USE_CUDNN_LSTM_PROJ
-    bool clip_state = param_.lstm_state_clip_min.has_value();
-    bool clip_nan = param_.lstm_state_clip_nan;
-    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
-                               rnn_desc_,
-                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
-                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
-                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
-                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
-    #endif
-
-    if (ctx.is_train) {
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
-                                           rnn_desc_,
-                                           x_data_desc_,
-                                           x.dptr_,
-                                           hx_desc_,
-                                           hx.dptr_,
-                                           cx_desc_,
-                                           cx_ptr,
-                                           w_desc_,
-                                           w.dptr_,
-                                           y_data_desc_,
-                                           y.dptr_,
-                                           hy_desc_,
-                                           hy_ptr,
-                                           cy_desc_,
-                                           cy_ptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           nullptr,
-                                           temp_space.dptr_,
-                                           workspace_byte_,
-                                           reserve_space_.dptr,
-                                           reserve_space_byte_));
-      #else
-      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
-                                         rnn_desc_,
-                                         param_.seq_length_,
-                                         x_desc_vec_.data(),
-                                         x.dptr_,
-                                         hx_desc_,
-                                         hx.dptr_,
-                                         cx_desc_,
-                                         cx_ptr,
-                                         w_desc_,
-                                         w.dptr_,
-                                         y_desc_vec_.data(),
-                                         y.dptr_,
-                                         hy_desc_,
-                                         hy_ptr,
-                                         cy_desc_,
-                                         cy_ptr,
-                                         temp_space.dptr_,
-                                         workspace_byte_,
-                                         reserve_space_.dptr,
-                                         reserve_space_byte_));
-      #endif
-    } else {
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
-                                            rnn_desc_,
-                                            x_data_desc_,
-                                            x.dptr_,
-                                            hx_desc_,
-                                            hx.dptr_,
-                                            cx_desc_,
-                                            cx_ptr,
-                                            w_desc_,
-                                            w.dptr_,
-                                            y_data_desc_,
-                                            y.dptr_,
-                                            hy_desc_,
-                                            hy_ptr,
-                                            cy_desc_,
-                                            cy_ptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            nullptr,
-                                            temp_space.dptr_,
-                                            workspace_byte_));
-      #else
-      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
-                                          rnn_desc_,
-                                          param_.seq_length_,
-                                          x_desc_vec_.data(),
-                                          x.dptr_,
-                                          hx_desc_,
-                                          hx.dptr_,
-                                          cx_desc_,
-                                          cx_ptr,
-                                          w_desc_,
-                                          w.dptr_,
-                                          y_desc_vec_.data(),
-                                          y.dptr_,
-                                          hy_desc_,
-                                          hy_ptr,
-                                          cy_desc_,
-                                          cy_ptr,
-                                          temp_space.dptr_,
-                                          workspace_byte_));
-      #endif
-    }
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    size_t in_expected = param_.lstm_q_ ? 4 : 3;
-    size_t out_expected = param_.lstm_q_ ? 3 : 2;
-    if (!param_.state_outputs)
-      out_expected = 1;
-
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
-    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
-    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
-    Stream<gpu> *s = ctx.get_stream<gpu>();
-    // get input + output tensors
-    Tensor<gpu, 3, DType> x = in_data[rnn_enum::kData].get<gpu, 3, DType>(s);
-    Tensor<gpu, 3, DType> dx = in_grad[rnn_enum::kData].get<gpu, 3, DType>(s);
-    Tensor<gpu, 1, DType> w = in_data[rnn_enum::kParams].get<gpu, 1, DType>(s);
-    Tensor<gpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<gpu, 1, DType>(s);
-    Tensor<gpu, 3, DType> hx = in_data[rnn_enum::kState].get<gpu, 3, DType>(s);
-    Tensor<gpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<gpu, 3, DType>(s);
-    Tensor<gpu, 3, DType> y = out_data[rnn_enum::kOut].get<gpu, 3, DType>(s);
-    Tensor<gpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<gpu, 3, DType>(s);
-    if (req[rnn_enum::kParams] != kAddTo) {
-      dw = mshadow::expr::ScalarExp<DType>(0.0f);
-    }
-    // only need kStateOut grad output_states is true
-    void * dhy_ptr = NULL;
-    if (param_.state_outputs)
-      dhy_ptr = out_grad[rnn_enum::kStateOut].get<gpu, 3, DType>(s).dptr_;
-
-    // Deal with lstm
-    void * dcx_ptr = NULL;
-    void * dcy_ptr = NULL;
-    void * cx_ptr = NULL;
-
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
-      cx_ptr = (in_data[rnn_enum::kStateCell].get<gpu, 3, DType>(s)).dptr_;
-      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<gpu, 3, DType>(s)).dptr_;
-    }
-    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
-        dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<gpu, 3, DType>(s)).dptr_;
-
-    CHECK_EQ(x.CheckContiguous(), true);
-    CHECK_EQ(w.CheckContiguous(), true);
-    CHECK_EQ(dw.CheckContiguous(), true);
-    CHECK_EQ(hx.CheckContiguous(), true);
-    CHECK_EQ(dhx.CheckContiguous(), true);
-    CHECK_EQ(y.CheckContiguous(), true);
-    CHECK_EQ(dy.CheckContiguous(), true);
-
-    if (!init_cudnn_) {
-      Init(s, in_data, out_data);
-    }
-
-    // Get temp space
-    int temp_size = workspace_size_;
-    Tensor<gpu, 1, DType> temp_space =
-      ctx.requested[rnn_enum::kTempSpace].get_space_typed<gpu, 1, DType>(
-                              mshadow::Shape1(temp_size), s);
-    #if USE_CUDNN_LSTM_PROJ
-    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
-                                      rnn_desc_,
-                                      y_data_desc_,
-                                      y.dptr_,
-                                      dy_data_desc_,
-                                      dy.dptr_,
-                                      nullptr,
-                                      nullptr,
-                                      dhy_desc_,
-                                      dhy_ptr,
-                                      dcy_desc_,
-                                      dcy_ptr,
-                                      w_desc_,
-                                      w.dptr_,
-                                      hx_desc_,
-                                      hx.dptr_,
-                                      cx_desc_,
-                                      cx_ptr,
-                                      dx_data_desc_,
-                                      dx.dptr_,
-                                      dhx_desc_,
-                                      dhx.dptr_,
-                                      dcx_desc_,
-                                      dcx_ptr,
-                                      nullptr,
-                                      nullptr,
-                                      temp_space.dptr_,
-                                      workspace_byte_,
-                                      reserve_space_.dptr,
-                                      reserve_space_byte_));
-    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
-                                         rnn_desc_,
-                                         x_data_desc_,
-                                         x.dptr_,
-                                         hx_desc_,
-                                         hx.dptr_,
-                                         y_data_desc_,
-                                         y.dptr_,
-                                         temp_space.dptr_,
-                                         workspace_byte_,
-                                         dw_desc_,
-                                         dw.dptr_,
-                                         reserve_space_.dptr,
-                                         reserve_space_byte_));
-    #else
-    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
-                                    rnn_desc_,
-                                    param_.seq_length_,
-                                    y_desc_vec_.data(),
-                                    y.dptr_,
-                                    dy_desc_vec_.data(),
-                                    dy.dptr_,
-                                    dhy_desc_,
-                                    dhy_ptr,
-                                    dcy_desc_,
-                                    dcy_ptr,
-                                    w_desc_,
-                                    w.dptr_,
-                                    hx_desc_,
-                                    hx.dptr_,
-                                    cx_desc_,
-                                    cx_ptr,
-                                    dx_desc_vec_.data(),
-                                    dx.dptr_,
-                                    dhx_desc_,
-                                    dhx.dptr_,
-                                    dcx_desc_,
-                                    dcx_ptr,
-                                    temp_space.dptr_,
-                                    workspace_byte_,
-                                    reserve_space_.dptr,
-                                    reserve_space_byte_));
-    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
-                                       rnn_desc_,
-                                       param_.seq_length_,
-                                       x_desc_vec_.data(),
-                                       x.dptr_,
-                                       hx_desc_,
-                                       hx.dptr_,
-                                       y_desc_vec_.data(),
-                                       y.dptr_,
-                                       temp_space.dptr_,
-                                       workspace_byte_,
-                                       dw_desc_,
-                                       dw.dptr_,
-                                       reserve_space_.dptr,
-                                       reserve_space_byte_));
-    #endif
-  }
-
- private:
-  inline void Init(mshadow::Stream<gpu> *s,
-                   const std::vector<TBlob> &in_data,
-                   const std::vector<TBlob> &out_data) {
-    using namespace mshadow;
-    #if CUDNN_MAJOR >= 5
-    format_ = CUDNN_TENSOR_NCHW;
-    #endif
-    size_t in_expected = param_.lstm_q_ ? 4 : 3;
-    size_t out_expected = param_.lstm_q_ ? 3 : 2;
-    if (!param_.state_outputs)
-      out_expected = 1;
-
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    if (!init_cudnn_) {
-      init_cudnn_ = true;
-      // get input + output tensors
-      Tensor<gpu, 3, DType> x = in_data[rnn_enum::kData].get<gpu, 3, DType>(s);
-      Tensor<gpu, 1, DType> w = in_data[rnn_enum::kParams].get<gpu, 1, DType>(s);
-      param_.seq_length_ = x.shape_[0];
-      param_.batch_size_ = x.shape_[1];
-      param_.input_size_ = x.shape_[2];
-
-      // Tensor Descriptors
-      std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
-      std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
-      std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
-      std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
-      int dimA[3];
-      int strideA[3];
-      for (int i = 0; i < param_.seq_length_; i++) {
-        CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
-        CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
-        CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
-        CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
-
-        dimA[0] = param_.batch_size_;
-        dimA[1] = param_.input_size_;
-        dimA[2] = 1;
-        strideA[0] = dimA[2] * dimA[1];
-        strideA[1] = dimA[2];
-        strideA[2] = 1;
-
-        CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
-                                              dtype_,
-                                              3,
-                                              dimA,
-                                              strideA));
-        CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
-                                              dtype_,
-                                              3,
-                                              dimA,
-                                              strideA));
-        dimA[0] = param_.batch_size_;
-        dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
-        dimA[2] = 1;
-        strideA[0] = dimA[2] * dimA[1];
-        strideA[1] = dimA[2];
-        strideA[2] = 1;
-
-        CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
-                                              dtype_,
-                                              3,
-                                              dimA,
-                                              strideA));
-        CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
-                                              dtype_,
-                                              3,
-                                              dimA,
-                                              strideA));
-      }
-      x_desc_vec_ = x_vec;
-      y_desc_vec_ = y_vec;
-      dx_desc_vec_ = dx_vec;
-      dy_desc_vec_ = dy_vec;
-
-      // set the state tensors
-      dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
-      dimA[1] = param_.batch_size_;
-      dimA[2] = param_.state_size;
-      strideA[0] = dimA[2] * dimA[1];
-      strideA[1] = dimA[2];
-      strideA[2] = 1;
-      #if USE_CUDNN_LSTM_PROJ
-      int dimB[3];
-      int strideB[3];
-      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
-      dimB[1] = param_.batch_size_;
-      dimB[2] = param_.projection_size.has_value() ?
-                param_.projection_size.value() : param_.state_size;
-      strideB[0] = dimB[2] * dimB[1];
-      strideB[1] = dimB[2];
-      strideB[2] = 1;
-      #endif
-
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimB,
-                                            strideB));
-      #else
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #endif
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimB,
-                                            strideB));
-      #else
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #endif
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimB,
-                                            strideB));
-      #else
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #endif
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #if USE_CUDNN_LSTM_PROJ
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimB,
-                                            strideB));
-      #else
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-      #endif
-      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
-                                            dtype_,
-                                            3,
-                                            dimA,
-                                            strideA));
-
-      // Create Dropout descriptors
-      if (param_.p > 0) {
-        CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
-        dropout_size_ = dropout_byte_ / sizeof(DType);
-        dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
-      } else {
-        dropout_states_ = {};
-        dropout_byte_ = 0;
-      }
-      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
-                                           param_.p,  // discard probability
-                                           dropout_states_.dptr, dropout_byte_,
-                                           seed_));
-      // RNN descriptors
-      #if CUDNN_MAJOR >= 6
-      cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
-      CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
-                                          rnn_desc_,
-                                          param_.state_size,
-                                          param_.num_layers,
-                                          dropout_desc_,
-                                          input_mode_,
-                                          direction_,
-                                          mode_,
-                                          rnn_algo,
-                                          dtype_));
-      #else
-      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
-                                       param_.state_size,
-                                       param_.num_layers,
-                                       dropout_desc_,
-                                       input_mode_,
-                                       direction_,
-                                       mode_,
-                                       dtype_));
-      #endif
-      #if CUDNN_MAJOR >= 7
-        cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
-        if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) {
-          math_type = CUDNN_TENSOR_OP_MATH;
-        }
-      #if CUDNN_VERSION >= 7200
-            if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
-                (DataType<DType>::kFlag != kFloat16))
-              math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
-      #endif
-        CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
-      #endif
-      #if USE_CUDNN_LSTM_PROJ
-      if (param_.projection_size.has_value()) {
-        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
-                                               rnn_desc_,
-                                               param_.projection_size.value(),
-                                               0));
-      }
-      #endif
-      // Get temp space sizes
-      CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
-                                          rnn_desc_,
-                                          param_.seq_length_,
-                                          x_desc_vec_.data(),
-                                          &workspace_byte_));
-      CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_,
-                                                rnn_desc_,
-                                                param_.seq_length_,
-                                                x_desc_vec_.data(),
-                                                &reserve_space_byte_));
-      workspace_size_ = workspace_byte_ / sizeof(DType);
-      // Allocate the reserve space
-      reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
-
-      // Check that number of params are correct
-      size_t cudnn_param_size;
-      CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
-                                       rnn_desc_,
-                                       x_desc_vec_[0],
-                                       &cudnn_param_size,
-                                       dtype_));
-      CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);
-
-      // Set param descriptors
-      int dim_w[3] = {1, 1, 1};
-      dim_w[0] = w.shape_[0];
-      CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_,
-                                            dtype_,
-                                            format_,
-                                            3,
-                                            dim_w));
-      CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_,
-                                            dtype_,
-                                            format_,
-                                            3,
-                                            dim_w));
-
-      // Query weight layout
-      // cudnnFilterDescriptor_t m_desc;
-      // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS);
-      // DType *p;
-      // int n = 2;
-      // int64_t last = 0;
-      // if (param_.mode == rnn_enum::kLstm) n = 8;
-      // else if (param_.mode == rnn_enum::kGru) n = 6;
-
-      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
-      //   for (int j = 0; j < n; ++j) {
-      //     CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_,
-      //       i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
-      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
-      //     last = ((int64_t)(p - NULL))/sizeof(DType);
-      //     cudnnDataType_t t;
-      //     cudnnTensorFormat_t f;
-      //     int ndim = 5;
-      //     int dims[5] = {0, 0, 0, 0, 0};
-      //     CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]),
-      //       CUDNN_STATUS_SUCCESS);
-      //     LOG(INFO) << "w: " <<  i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
-      //     for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i];
-      //   }
-      // }
-
-      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
-      //   for (int j = 0; j < n; ++j) {
-      //     CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0],
-      //       w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
-      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
-      //     last = ((int64_t)(p - NULL))/sizeof(DType);
-      //     LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
-      //   }
-      // }
-    }
-  }
-
-  cudnnDataType_t dtype_;
-  bool init_cudnn_;
-  cudnnRNNDescriptor_t rnn_desc_;
-  cudnnRNNMode_t mode_;
-  cudnnDirectionMode_t direction_;
-  cudnnRNNInputMode_t input_mode_;
-  cudnnDropoutDescriptor_t dropout_desc_;
-  Storage::Handle dropout_states_, reserve_space_;
-  uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
-  size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
-  int workspace_size_, dropout_size_;
-  std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
-  #if USE_CUDNN_LSTM_PROJ
-  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
-  #endif
-  cudnnTensorDescriptor_t hx_desc_, cx_desc_;
-  cudnnTensorDescriptor_t hy_desc_, cy_desc_;
-  cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
-  cudnnTensorDescriptor_t dhy_desc_, dcy_desc_;
-
-  cudnnFilterDescriptor_t w_desc_, dw_desc_;
-  // Allow TensorCore algo policy
-  bool cudnn_tensor_core_;
-
-  #if CUDNN_MAJOR >= 5
-  cudnnTensorFormat_t format_;
-  #endif
-  RNNParam param_;
-};
-#endif  // __CUDACC__ && CUDNN
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_CUDNN_RNN_INL_H_
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 71ad331..37f21ce 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -26,6 +26,9 @@
 #ifndef MXNET_OPERATOR_RNN_INL_H_
 #define MXNET_OPERATOR_RNN_INL_H_
 
+#define MXNET_USE_CUDNN_RNN MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
+#define USE_CUDNN_LSTM_PROJ MXNET_USE_CUDNN == 1 && CUDNN_VERSION >= 7200
+
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
@@ -35,6 +38,7 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include <cstdint>
 #include "./math.h"
 #include "./math_functions-inl.h"
 #include "./operator_common.h"
@@ -47,7 +51,7 @@ namespace rnn_enum {
   enum RNNOpInputs {kData, kParams, kState, kStateCell};
   enum RNNOpOutputs {kOut, kStateOut, kStateCellOut};
   enum RNNModeType {kRnnRelu, kRnnTanh, kLstm, kGru};
-  enum RNNOpResource {kTempSpace};
+  enum RNNOpResource {kCuDNNDropoutDescSpace};
 }
 
 inline int GetRnnParamSize(int num_layer,
@@ -160,9 +164,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   uint32_t num_layers;
   bool bidirectional, state_outputs;
   int mode;
-  float p, pkeep_;
+  float p;
   int seq_length_, batch_size_, input_size_;
-  bool lstm_q_;  // whether type is lstm
   dmlc::optional<int> projection_size;
   dmlc::optional<double> lstm_state_clip_min, lstm_state_clip_max;
   bool lstm_state_clip_nan;
@@ -212,7 +215,6 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
   }
 };
 
-
 /**
  * @params: ws: Temp workspace for gemm's output storage.
  *          rs: Reserve space of forward intermediate data used for training.
@@ -236,6 +238,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
  *                  hy's shape is [num_layers, batch_size, state_size]
  *          cy_ptr: Only used in lstm mode. pointer of tensor cy  containing the cell state
  *                  for t=seq_length. cy' shape is [num_layers, batch_size, state_size]
+ *          dropout: should be 0 <= dropout < 1
  *          mode: Specifies the type of RNN to compute.
  */
 template <typename DType>
@@ -376,58 +379,189 @@ void RNNBackward(DType* ws,
   }
 }
 
-template<typename DType>
-class RNNOp : public Operator{
+template<typename xpu, typename DType>
+class RNNOp {
  public:
-  explicit RNNOp(RNNParam p)
-    :param_(p), init_space_(false), reserve_space_size_(0) {
+  RNNParam param_;
+  Context ctx_;
+  explicit RNNOp(RNNParam param, Context ctx) {
+    this->param_ = param;
+    this->ctx_ = ctx;
+    #if MXNET_USE_CUDNN_RNN
+    init_cudnn_ = false;
+    dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+    // TensorCore algos only allowed on fp16-I/O convolutions if permitted by the global policy.
+    // No tests in place for fp16 RNNs, so leave TensorCore disabled for now.
+    cudnn_tensor_core_ = false;
+    // When fp16 RNN tests are introduced, we can enable TensorCore as follows:
+//    cudnn_tensor_core =
+//        mshadow::DataType<DType>::kFlag == mshadow::kFloat16 && GetEnvAllowTensorCore();
+    // Defaults
+    input_mode_ = CUDNN_LINEAR_INPUT;  // Don't support this yet
+    // RNN Mode
+    switch (param_.mode) {
+      case rnn_enum::kRnnRelu:
+        mode_ = CUDNN_RNN_RELU;
+        break;
+      case rnn_enum::kRnnTanh:
+        mode_ = CUDNN_RNN_TANH;
+        break;
+      case rnn_enum::kLstm:
+        mode_ = CUDNN_LSTM;
+        break;
+      case rnn_enum::kGru:
+        mode_ = CUDNN_GRU;
+        break;
+      default:
+        LOG(FATAL) << "Not implmented";
+    }
+#if USE_CUDNN_LSTM_PROJ
     if (param_.projection_size.has_value()) {
-      LOG(FATAL) << "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "Projection is only supported for LSTM.";
+      CHECK_GE(param_.state_size, param_.projection_size.value())
+        << "State size must be larger than projection size.";
     }
+#else
+    CHECK(!param_.projection_size.has_value())
+      << "Projection is only supported for LSTM with CuDNN version later than 7.1.1.";
+#endif
+#if USE_CUDNN_LSTM_PROJ
     if (param_.lstm_state_clip_min.has_value()
         || param_.lstm_state_clip_max.has_value()) {
-      LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
+      CHECK_EQ(param_.mode, rnn_enum::kLstm)
+        << "State clipping is only supported for LSTM.";
+      CHECK(param_.lstm_state_clip_min.has_value() && param_.lstm_state_clip_max.has_value())
+        << "lstm_state_clip_min and lstm_state_clip_max must be specified together.";
+      CHECK_GE(param_.lstm_state_clip_max.value(), param_.lstm_state_clip_min.value())
+        << "lstm_state_clip_max must be greater or equal to lstm_state_clip_min";
+    }
+#else
+    CHECK(!param_.lstm_state_clip_min.has_value()
+          && !param_.lstm_state_clip_max.has_value())
+      << "State clipping is only supported for LSTM with CuDNN version later than 7.2.1.";
+#endif
+    // RNN Direction
+    direction_ = param_.bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL;
+    // Create descriptors
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
+    CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));
+
+    CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
+    CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
+
+    CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
+    CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&x_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&y_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dx_data_desc_));
+    CUDNN_CALL(cudnnCreateRNNDataDescriptor(&dy_data_desc_));
+    #endif
+    #else
+    if (ctx_.dev_type == kGPU) {
+      LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
+    }
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      this->init_space_ = false;
+      this->temp_init_space_ = false;
+      this->reserve_cpu_space_size_ = 0;
+      this->temp_cpu_space_size_ = 0;
+
+      if (param_.projection_size.has_value()) {
+        LOG(FATAL) <<
+            "hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
+      }
+      if (param_.lstm_state_clip_min.has_value()
+          || param_.lstm_state_clip_max.has_value()) {
+        LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
+      }
     }
   }
 
   ~RNNOp() {
-    if (init_space_) {
+    #if MXNET_USE_CUDNN_RNN
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
+    CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));
+
+    CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
+    CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
+    CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
+
+    if (init_cudnn_) {
+      for (size_t i = 0; i < x_desc_vec_.size(); ++i) {
+        CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i]));
+        CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i]));
+        CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]));
+        CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
+      }
+      init_cudnn_ = false;
+      Storage::Get()->Free(temp_space_);
       Storage::Get()->Free(reserve_space_);
-      init_space_ = false;
+    }
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(x_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(y_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dx_data_desc_));
+    CUDNN_CALL(cudnnDestroyRNNDataDescriptor(dy_data_desc_));
+    #endif
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      if (init_space_) {
+        Storage::Get()->Free(reserve_cpu_space_);
+        init_space_ = false;
+      }
+      if (temp_init_space_) {
+        Storage::Get()->Free(temp_cpu_space_);
+        temp_init_space_ = false;
+      }
     }
   }
 
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
+  void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &out_data) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
-
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    Stream<cpu> *s = ctx.get_stream<cpu>();
-    // get input + output tensor
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    // get input + output tensors
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];
-
     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
     DType* b_ptr = w.dptr_ + w.shape_[0] - bsize;
@@ -438,124 +572,308 @@ class RNNOp : public Operator{
     }
     DType* cx_ptr = NULL;
     DType* cy_ptr = NULL;
-
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
 
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+    #if USE_CUDNN_LSTM_PROJ
+    std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         param_.input_size_,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    int out_size =
+      (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+    out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+    CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+                                         dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                         param_.seq_length_,
+                                         param_.batch_size_,
+                                         out_size,
+                                         seqLengthArray.data(),
+                                         nullptr));
+    if (ctx.is_train) {
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           param_.input_size_,
+                                           seqLengthArray.data(),
+                                           nullptr));
+      CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+                                           dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+                                           param_.seq_length_,
+                                           param_.batch_size_,
+                                           out_size,
+                                           seqLengthArray.data(),
+                                           nullptr));
+    }
+    #endif
+
+    #if USE_CUDNN_LSTM_PROJ
+    bool clip_state = param_.lstm_state_clip_min.has_value();
+    bool clip_nan = param_.lstm_state_clip_nan;
+    CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+                               rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+    #endif
 
     if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
-      if (init_space_ && reserve_space_size_ < r_size) {
-        Storage::Get()->Free(reserve_space_);
-        init_space_ = false;
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+                                           rnn_desc_,
+                                           x_data_desc_,
+                                           x.dptr_,
+                                           hx_desc_,
+                                           hx.dptr_,
+                                           cx_desc_,
+                                           cx_ptr,
+                                           w_desc_,
+                                           w.dptr_,
+                                           y_data_desc_,
+                                           y.dptr_,
+                                           hy_desc_,
+                                           hy_ptr,
+                                           cy_desc_,
+                                           cy_ptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           nullptr,
+                                           temp_space_.dptr,
+                                           workspace_byte_,
+                                           reserve_space_.dptr,
+                                           reserve_space_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+                                         rnn_desc_,
+                                         param_.seq_length_,
+                                         x_desc_vec_.data(),
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         cx_desc_,
+                                         cx_ptr,
+                                         w_desc_,
+                                         w.dptr_,
+                                         y_desc_vec_.data(),
+                                         y.dptr_,
+                                         hy_desc_,
+                                         hy_ptr,
+                                         cy_desc_,
+                                         cy_ptr,
+                                         temp_space_.dptr,
+                                         workspace_byte_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+      #endif
+    } else {
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+                                            rnn_desc_,
+                                            x_data_desc_,
+                                            x.dptr_,
+                                            hx_desc_,
+                                            hx.dptr_,
+                                            cx_desc_,
+                                            cx_ptr,
+                                            w_desc_,
+                                            w.dptr_,
+                                            y_data_desc_,
+                                            y.dptr_,
+                                            hy_desc_,
+                                            hy_ptr,
+                                            cy_desc_,
+                                            cy_ptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            nullptr,
+                                            temp_space_.dptr,
+                                            workspace_byte_));
+      #else
+      CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          x.dptr_,
+                                          hx_desc_,
+                                          hx.dptr_,
+                                          cx_desc_,
+                                          cx_ptr,
+                                          w_desc_,
+                                          w.dptr_,
+                                          y_desc_vec_.data(),
+                                          y.dptr_,
+                                          hy_desc_,
+                                          hy_ptr,
+                                          cy_desc_,
+                                          cy_ptr,
+                                          temp_space_.dptr,
+                                          workspace_byte_));
+      #endif
+    }
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      if (temp_init_space_ && temp_cpu_space_size_ < work_cpu_space_size) {
+          Storage::Get()->Free(temp_cpu_space_);
+          temp_init_space_ = false;
       }
-
-      if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
-        reserve_space_size_ = r_size;
-        init_space_ = true;
+      if (!temp_init_space_) {
+        temp_cpu_space_ = Storage::Get()->Alloc
+            (work_cpu_space_size * sizeof(DType), Context::CPU());
+        temp_cpu_space_size_ = work_cpu_space_size;
+        temp_init_space_ = true;
+      }
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+      if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+        if (init_space_ && reserve_cpu_space_size_ < r_size) {
+          Storage::Get()->Free(reserve_cpu_space_);
+          init_space_ = false;
+        }
+        if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+          reserve_cpu_space_size_ = r_size;
+          init_space_ = true;
+        }
+
+        DType* reserve_space_ptr = 
static_cast<DType*>(reserve_cpu_space_.dptr);
+
+        RNNForwardTraining<DType>(work_cpu_space,
+                                  reserve_space_ptr,
+                                  param_.state_outputs,
+                                  param_.num_layers,
+                                  direction,
+                                  param_.seq_length_,
+                                  param_.batch_size_,
+                                  param_.input_size_,
+                                  param_.state_size,
+                                  x.dptr_,
+                                  hx.dptr_,
+                                  cx_ptr,
+                                  w.dptr_,
+                                  b_ptr,
+                                  y.dptr_,
+                                  hy_ptr,
+                                  cy_ptr,
+                                  param_.p,
+                                  param_.mode);
+      } else {
+        RNNForwardInference<DType>(work_cpu_space,
+                                   param_.state_outputs,
+                                   param_.num_layers,
+                                   direction,
+                                   param_.seq_length_,
+                                   param_.batch_size_,
+                                   param_.input_size_,
+                                   param_.state_size,
+                                   x.dptr_,
+                                   hx.dptr_,
+                                   cx_ptr,
+                                   w.dptr_,
+                                   b_ptr,
+                                   y.dptr_,
+                                   hy_ptr,
+                                   cy_ptr,
+                                   param_.mode);
       }
-
-      DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-      RNNForwardTraining<DType>(workspace.dptr_,
-                                reserve_space_ptr,
-                                param_.state_outputs,
-                                param_.num_layers,
-                                direction,
-                                param_.seq_length_,
-                                param_.batch_size_,
-                                param_.input_size_,
-                                param_.state_size,
-                                x.dptr_,
-                                hx.dptr_,
-                                cx_ptr,
-                                w.dptr_,
-                                b_ptr,
-                                y.dptr_,
-                                hy_ptr,
-                                cy_ptr,
-                                param_.p,
-                                param_.mode);
-    } else {
-      RNNForwardInference<DType>(workspace.dptr_,
-                                 param_.state_outputs,
-                                 param_.num_layers,
-                                 direction,
-                                 param_.seq_length_,
-                                 param_.batch_size_,
-                                 param_.input_size_,
-                                 param_.state_size,
-                                 x.dptr_,
-                                 hx.dptr_,
-                                 cx_ptr,
-                                 w.dptr_,
-                                 b_ptr,
-                                 y.dptr_,
-                                 hy_ptr,
-                                 cy_ptr,
-                                 param_.mode);
     }
   }
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
+  /*!
+   * \brief Backward pass of the RNN operator.
+   *
+   * Checks the dropout value and the number of inputs/outputs/gradients,
+   * clears dw unless req[kParams] is kAddTo, then either issues the cuDNN
+   * backward-data and backward-weights calls (when built with
+   * MXNET_USE_CUDNN_RNN under the CUDA compiler) or, when ctx_.dev_type is
+   * kCPU, calls RNNBackward.
+   *
+   * The CPU path reuses the workspace and reserve buffers allocated by
+   * Forward; it aborts with LOG(FATAL) when they are missing or too small.
+   *
+   * \param ctx operator context (stream, requested resources)
+   * \param out_grad gradients w.r.t. the outputs (kOut, and the state outputs
+   *        when param_.state_outputs is set)
+   * \param in_data forward inputs (data, parameters, state, and state cell for LSTM)
+   * \param out_data forward outputs
+   * \param req write request type for each entry of in_grad
+   * \param in_grad gradients w.r.t. the inputs, written by this call
+   */
+  void Backward(const OpContext &ctx,
+                const std::vector<TBlob> &out_grad,
+                const std::vector<TBlob> &in_data,
+                const std::vector<TBlob> &out_data,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
     CHECK(param_.p >= 0.0f && param_.p < 1.0f)
         << "unsupported dropout value, should be 0 <= dropout < 1";
 
-    size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
-    size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
-    if (!param_.state_outputs) {
-      out_expected = 1;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    CHECK_EQ(in_data.size(), in_expected);
-    CHECK_EQ(out_data.size(), out_expected);
-    CHECK_EQ(in_grad.size(), in_expected);
-    CHECK_EQ(out_grad.size(), out_expected);
-    CHECK_EQ(req.size(), in_expected);
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+    CHECK_EQ(in_grad.size(), num_inputs);
+    CHECK_EQ(out_grad.size(), num_outputs);
+    CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for 
data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for 
state";
-    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    Stream<xpu> *s = ctx.get_stream<xpu>();
     // get input + output tensors
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, 
DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, 
DType>(s);
-    Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
-    CHECK(x.CheckContiguous());
-    CHECK(w.CheckContiguous());
-    CHECK(hx.CheckContiguous());
-    CHECK(y.CheckContiguous());
-    CHECK(dx.CheckContiguous());
-    CHECK(dw.CheckContiguous());
-    CHECK(dhx.CheckContiguous());
-    CHECK(dy.CheckContiguous());
+    Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+    Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, 
DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, 
DType>(s);
+    Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(dw.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(dhx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+    CHECK_EQ(dy.CheckContiguous(), true);
+    CHECK_EQ(dx.CheckContiguous(), true);
+
+    if (req[rnn_enum::kParams] != kAddTo) {
+      dw = mshadow::expr::ScalarExp<DType>(0.0f);
+    }
+
     param_.seq_length_ = x.shape_[0];
     param_.batch_size_ = x.shape_[1];
     param_.input_size_ = x.shape_[2];
 
     const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, 
direction, param_.mode);
+
+    // The bias gradients occupy the last bsize elements of the flat dw vector.
     DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
 
     DType * dhy_ptr = NULL;
@@ -563,260 +881,582 @@ class RNNOp : public Operator{
       dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
     }
 
-    DType * cx_ptr = NULL;
-    DType * dcx_ptr = NULL;
-    DType * dcy_ptr = NULL;
+    DType* dcx_ptr = NULL;
+    DType* dcy_ptr = NULL;
+    DType* cx_ptr = NULL;
 
     if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported 
for state cell";
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
-      }
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+      dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
     }
+    if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+        dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, 
DType>(s)).dptr_;
 
-    // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, 
param_.batch_size_,
-                                                      param_.state_size, 
direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
-    size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, 
param_.batch_size_,
-                                           param_.state_size, param_.mode);
-    if (!init_space_ || reserve_space_size_ != r_size) {
-      LOG(FATAL) << "Check forward init error";
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
     }
 
-    DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
-    RNNBackward<DType>(workspace.dptr_,
-                       reserve_space_ptr,
-                       param_.num_layers,
-                       direction,
-                       param_.seq_length_,
-                       param_.batch_size_,
-                       param_.input_size_,
-                       param_.state_size,
-                       x.dptr_,
-                       hx.dptr_,
-                       cx_ptr,
-                       w.dptr_,
-                       y.dptr_,
-                       dy.dptr_,
-                       dhy_ptr,
-                       dcy_ptr,
-                       dx.dptr_,
-                       dhx.dptr_,
-                       dcx_ptr,
-                       dw.dptr_,
-                       db_ptr,
-                       req[rnn_enum::kData],
-                       req[rnn_enum::kParams],
-                       req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is 
absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? 
req[rnn_enum::kStateCell] : kNullOp,
-                       param_.p,
-                       param_.mode);
-  }
-
- private:
-  RNNParam param_;
-  bool init_space_;
-  size_t reserve_space_size_;
-  Storage::Handle reserve_space_;
-};  // class RNNOp
+    #if USE_CUDNN_LSTM_PROJ
+    CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+                                      rnn_desc_,
+                                      y_data_desc_,
+                                      y.dptr_,
+                                      dy_data_desc_,
+                                      dy.dptr_,
+                                      nullptr,
+                                      nullptr,
+                                      dhy_desc_,
+                                      dhy_ptr,
+                                      dcy_desc_,
+                                      dcy_ptr,
+                                      w_desc_,
+                                      w.dptr_,
+                                      hx_desc_,
+                                      hx.dptr_,
+                                      cx_desc_,
+                                      cx_ptr,
+                                      dx_data_desc_,
+                                      dx.dptr_,
+                                      dhx_desc_,
+                                      dhx.dptr_,
+                                      dcx_desc_,
+                                      dcx_ptr,
+                                      nullptr,
+                                      nullptr,
+                                      temp_space_.dptr,
+                                      workspace_byte_,
+                                      reserve_space_.dptr,
+                                      reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+                                         rnn_desc_,
+                                         x_data_desc_,
+                                         x.dptr_,
+                                         hx_desc_,
+                                         hx.dptr_,
+                                         y_data_desc_,
+                                         y.dptr_,
+                                         temp_space_.dptr,
+                                         workspace_byte_,
+                                         dw_desc_,
+                                         dw.dptr_,
+                                         reserve_space_.dptr,
+                                         reserve_space_byte_));
+    #else
+    CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+                                    rnn_desc_,
+                                    param_.seq_length_,
+                                    y_desc_vec_.data(),
+                                    y.dptr_,
+                                    dy_desc_vec_.data(),
+                                    dy.dptr_,
+                                    dhy_desc_,
+                                    dhy_ptr,
+                                    dcy_desc_,
+                                    dcy_ptr,
+                                    w_desc_,
+                                    w.dptr_,
+                                    hx_desc_,
+                                    hx.dptr_,
+                                    cx_desc_,
+                                    cx_ptr,
+                                    dx_desc_vec_.data(),
+                                    dx.dptr_,
+                                    dhx_desc_,
+                                    dhx.dptr_,
+                                    dcx_desc_,
+                                    dcx_ptr,
+                                    temp_space_.dptr,
+                                    workspace_byte_,
+                                    reserve_space_.dptr,
+                                    reserve_space_byte_));
+    CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+                                       rnn_desc_,
+                                       param_.seq_length_,
+                                       x_desc_vec_.data(),
+                                       x.dptr_,
+                                       hx_desc_,
+                                       hx.dptr_,
+                                       y_desc_vec_.data(),
+                                       y.dptr_,
+                                       temp_space_.dptr,
+                                       workspace_byte_,
+                                       dw_desc_,
+                                       dw.dptr_,
+                                       reserve_space_.dptr,
+                                       reserve_space_byte_));
+    #endif
+    #endif
+
+    if (ctx_.dev_type == kCPU) {
+      // allocate temp space
+      const size_t work_cpu_space_size =
+          GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+                              param_.state_size, direction, param_.mode);
+      // Forward only reallocates the workspace when the cached one is too
+      // small, so the cached size may exceed what this shape needs. Reject
+      // only a missing or undersized buffer.
+      if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
+        LOG(FATAL) << "Check temp init error";
+      }
+      DType* work_cpu_space = static_cast<DType*>(temp_cpu_space_.dptr);
+      size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                             param_.seq_length_, 
param_.batch_size_,
+                                             param_.state_size, param_.mode);
 
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+      // Same grow-only policy as the workspace: a larger reserve buffer left
+      // over from an earlier, bigger Forward call is still usable.
+      if (!init_space_ || reserve_cpu_space_size_ < r_size) {
+        LOG(FATAL) << "Check forward init error";
+      }
 
-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    if (param_.mode == rnn_enum::kLstm) {
-      return {"data", "parameters", "state", "state_cell"};
-    } else {
-      return {"data", "parameters", "state"};
+      DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+      RNNBackward<DType>(work_cpu_space,
+                         reserve_space_ptr,
+                         param_.num_layers,
+                         direction,
+                         param_.seq_length_,
+                         param_.batch_size_,
+                         param_.input_size_,
+                         param_.state_size,
+                         x.dptr_,
+                         hx.dptr_,
+                         cx_ptr,
+                         w.dptr_,
+                         y.dptr_,
+                         dy.dptr_,
+                         dhy_ptr,
+                         dcy_ptr,
+                         dx.dptr_,
+                         dhx.dptr_,
+                         dcx_ptr,
+                         dw.dptr_,
+                         db_ptr,
+                         req[rnn_enum::kData],
+                         req[rnn_enum::kParams],
+                         req[rnn_enum::kState],
+                         // State cell should be present for LSTMs, but is 
absent for other RNNs.
+                         param_.mode == rnn_enum::kLstm ? 
req[rnn_enum::kStateCell] : kNullOp,
+                         param_.p,
+                         param_.mode);
     }
   }
 
-  std::vector<std::string> ListOutputs() const override {
-    std::vector<std::string> outputs = {"output"};
-    if (!param_.state_outputs)
-      return outputs;
-    else
-      outputs.emplace_back("state");
-    if (param_.mode == rnn_enum::kLstm)
-      outputs.emplace_back("state_cell");
-    return outputs;
-  }
-
-  int NumOutputs() const override {
-    int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
-    int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
-    return num_outputs;
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) 
override {
-    param_.Init(kwargs);
-  }
 
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(mxnet::ShapeVector *in_shape,
-                  mxnet::ShapeVector *out_shape,
-                  mxnet::ShapeVector *aux_shape) const override {
+ private:
+  inline void Init(const OpContext &ctx,
+                   mshadow::Stream<xpu> *s,
+                   const std::vector<TBlob> &in_data,
+                   const std::vector<TBlob> &out_data) {
     using namespace mshadow;
-    if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, 
cell_state]";
-    } else {
-      CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
-    }
-    const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
-    if (dshape.ndim() ==  0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch 
size, input size]";
-    // data: [sequence len, batch, input dimension]
-    int batch_size = dshape[1];
-    int input_size = dshape[2];
-    int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for 
bidirectional
-    int layer_size = (param_.projection_size.has_value()) ?
-                     param_.projection_size.value() : param_.state_size;
-    SHAPE_ASSIGN_CHECK(*in_shape,
-                       rnn_enum::kState,
-                       Shape3(total_layers, batch_size, layer_size));
-    if (param_.mode == rnn_enum::kLstm)
-      SHAPE_ASSIGN_CHECK(*in_shape,
-                         rnn_enum::kStateCell,
-                         Shape3(total_layers, batch_size, param_.state_size));
-
-    // calculate parameter vector length
-    int param_size = GetRnnParamSize(param_.num_layers,
-                                     input_size,
-                                     param_.state_size,
-                                     numDirections,
-                                     param_.mode,
-                                     param_.projection_size);
-    SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
-    out_shape->clear();
-    // output: [sequence len, batch, output size]
-    mxnet::TShape oshape = dshape;
-    if (param_.projection_size.has_value()) {
-      oshape[2] = numDirections * param_.projection_size.value();
-    } else {
-      oshape[2] = numDirections * param_.state_size;
+    size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+    //  kOut
+    size_t num_outputs = 1;
+    if (param_.state_outputs) {
+      // kOut, kStateOut, kStateCellOut
+      num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
     }
-    out_shape->push_back(oshape);
-    if (!param_.state_outputs) {
-      return true;
-    } else {
-      // outStateShape: [layer_num, batch, state size]
-      mxnet::TShape outStateShape = dshape;
-      outStateShape[0] = total_layers;
-      outStateShape[1] = batch_size;
-      if (param_.projection_size.has_value()) {
-        outStateShape[2] = param_.projection_size.value();
-      } else {
-        outStateShape[2] = param_.state_size;
+
+    CHECK_EQ(in_data.size(), num_inputs);
+    CHECK_EQ(out_data.size(), num_outputs);
+
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    #if CUDNN_MAJOR >= 5
+    format_ = CUDNN_TENSOR_NCHW;
+    #endif
+
+    if (!init_cudnn_) {
+      init_cudnn_ = true;
+      // get input + output tensors
+      Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+      Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, 
DType>(s);
+      param_.seq_length_ = x.shape_[0];
+      param_.batch_size_ = x.shape_[1];
+      param_.input_size_ = x.shape_[2];
+
+      // Tensor Descriptors
+      std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+      std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+      int dimA[3];
+      int strideA[3];
+      for (int i = 0; i < param_.seq_length_; i++) {
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+        CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.input_size_;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        dimA[0] = param_.batch_size_;
+        dimA[1] = param_.bidirectional ? param_.state_size * 2 : 
param_.state_size;
+        dimA[2] = 1;
+        strideA[0] = dimA[2] * dimA[1];
+        strideA[1] = dimA[2];
+        strideA[2] = 1;
+
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
+        CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+                                              dtype_,
+                                              3,
+                                              dimA,
+                                              strideA));
       }
-      out_shape->push_back(outStateShape);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm) {
-        mxnet::TShape cellStateShape = dshape;
-        cellStateShape[0] = total_layers;
-        cellStateShape[1] = batch_size;
-        cellStateShape[2] = param_.state_size;
-        out_shape->push_back(cellStateShape);
+      x_desc_vec_ = x_vec;
+      y_desc_vec_ = y_vec;
+      dx_desc_vec_ = dx_vec;
+      dy_desc_vec_ = dy_vec;
+
+      // set the state tensors
+      dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimA[1] = param_.batch_size_;
+      dimA[2] = param_.state_size;
+      strideA[0] = dimA[2] * dimA[1];
+      strideA[1] = dimA[2];
+      strideA[2] = 1;
+      #if USE_CUDNN_LSTM_PROJ
+      int dimB[3];
+      int strideB[3];
+      dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+      dimB[1] = param_.batch_size_;
+      dimB[2] = param_.projection_size.has_value() ?
+                param_.projection_size.value() : param_.state_size;
+      strideB[0] = dimB[2] * dimB[1];
+      strideB[1] = dimB[2];
+      strideB[2] = 1;
+      #endif
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #if USE_CUDNN_LSTM_PROJ
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimB,
+                                            strideB));
+      #else
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+      #endif
+      CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+                                            dtype_,
+                                            3,
+                                            dimA,
+                                            strideA));
+
+      // Create Dropout descriptors
+      DType* dropout_states_ = NULL;
+      if (param_.p > 0) {
+         ctx.requested[rnn_enum::kCuDNNDropoutDescSpace].get_cudnn_dropout_desc
+            (&dropout_desc_, s, 1.0f - param_.p, seed_);
+      } else {
+        dropout_byte_ = 0;
       }
-      return true;
-    }
-  }
 
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
+      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
+                                           param_.p,  // discard probability
+                                           dropout_states_, dropout_byte_,
+                                           seed_));
+
+      // RNN descriptors
+      #if CUDNN_MAJOR >= 6
+      cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
+      CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.state_size,
+                                          param_.num_layers,
+                                          dropout_desc_,
+                                          input_mode_,
+                                          direction_,
+                                          mode_,
+                                          rnn_algo,
+                                          dtype_));
+      #else
+      CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_,
+                                       param_.state_size,
+                                       param_.num_layers,
+                                       dropout_desc_,
+                                       input_mode_,
+                                       direction_,
+                                       mode_,
+                                       dtype_));
+      #endif
+      #if CUDNN_MAJOR >= 7
+        cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
+        if (cudnn_tensor_core_ && rnn_algo == CUDNN_RNN_ALGO_STANDARD) {
+          math_type = CUDNN_TENSOR_OP_MATH;
+        }
+      #if CUDNN_VERSION >= 7200
+            if (GetEnvAllowTensorCore() && GetEnvAllowTensorCoreConversion() &&
+                (DataType<DType>::kFlag != kFloat16))
+              math_type = CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION;
+      #endif
+        CUDNN_CALL(cudnnSetRNNMatrixMathType(rnn_desc_, math_type));
+      #endif
+      #if USE_CUDNN_LSTM_PROJ
+      if (param_.projection_size.has_value()) {
+        CUDNN_CALL(cudnnSetRNNProjectionLayers(s->dnn_handle_,
+                                               rnn_desc_,
+                                               param_.projection_size.value(),
+                                               0));
       }
+      #endif
+      // Get temp space sizes
+      CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_,
+                                          rnn_desc_,
+                                          param_.seq_length_,
+                                          x_desc_vec_.data(),
+                                          &workspace_byte_));
+      CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_,
+                                                rnn_desc_,
+                                                param_.seq_length_,
+                                                x_desc_vec_.data(),
+                                                &reserve_space_byte_));
+      workspace_size_ = workspace_byte_ / sizeof(DType);
+      // Allocate the reserve space
+      reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
+      // Allocate the temp space
+      temp_space_ = Storage::Get()->Alloc(workspace_byte_, Context::GPU(s->dev_id));
+      // Check that number of params are correct
+      size_t cudnn_param_size;
+      CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_,
+                                       rnn_desc_,
+                                       x_desc_vec_[0],
+                                       &cudnn_param_size,
+                                       dtype_));
+      CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);
+      // Set param descriptors
+      int dim_w[3] = {1, 1, 1};
+      dim_w[0] = w.shape_[0];
+      CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_,
+                                            dtype_,
+                                            format_,
+                                            3,
+                                            dim_w));
+      CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_,
+                                            dtype_,
+                                            format_,
+                                            3,
+                                            dim_w));
+
+      // Query weight layout
+      // cudnnFilterDescriptor_t m_desc;
+      // CHECK_EQ(cudnnCreateFilterDescriptor(&m_desc), CUDNN_STATUS_SUCCESS);
+      // DType *p;
+      // int n = 2;
+      // int64_t last = 0;
+      // if (param_.mode == rnn_enum::kLstm) n = 8;
+      // else if (param_.mode == rnn_enum::kGru) n = 6;
+
+      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
+      //   for (int j = 0; j < n; ++j) {
+      //     CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(s->dnn_handle_, rnn_desc_,
+      //       i, x_desc_vec_[0], w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
+      //     last = ((int64_t)(p - NULL))/sizeof(DType);
+      //     cudnnDataType_t t;
+      //     cudnnTensorFormat_t f;
+      //     int ndim = 5;
+      //     int dims[5] = {0, 0, 0, 0, 0};
+      //     CHECK_EQ(cudnnGetFilterNdDescriptor(m_desc, ndim, &t, &f, &ndim, &dims[0]),
+      //       CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << "w: " <<  i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
+      //     for (int i = 0; i < ndim; ++i) LOG(INFO) << dims[i];
+      //   }
+      // }
+
+      // for (int i = 0; i < param_.num_layers*(param_.bidirectional?2:1); ++i) {
+      //   for (int j = 0; j < n; ++j) {
+      //     CHECK_EQ(cudnnGetRNNLinLayerBiasParams(s->dnn_handle_, rnn_desc_, i, x_desc_vec_[0],
+      //       w_desc_, 0, j, m_desc, (void**)&p), CUDNN_STATUS_SUCCESS);
+      //     LOG(INFO) << ((int64_t)(p - NULL))/sizeof(DType) - last;
+      //     last = ((int64_t)(p - NULL))/sizeof(DType);
+      //     LOG(INFO) << "b: " << i << " " << j << " " << ((int64_t)(p - NULL))/sizeof(DType);
+      //   }
+      // }
     }
-    out_type->clear();
-    out_type->push_back(dtype);
-    if (!param_.state_outputs) {
-      return true;
+  #endif
+  }
+  #if MXNET_USE_CUDNN_RNN
+  cudnnDataType_t dtype_;
+  bool init_cudnn_;
+  cudnnRNNDescriptor_t rnn_desc_;
+  cudnnRNNMode_t mode_;
+  cudnnDirectionMode_t direction_;
+  cudnnRNNInputMode_t input_mode_;
+  cudnnDropoutDescriptor_t dropout_desc_;
+  Storage::Handle reserve_space_, temp_space_;
+  uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  size_t workspace_byte_, reserve_space_byte_, dropout_byte_;
+  int workspace_size_;
+  std::vector<cudnnTensorDescriptor_t> x_desc_vec_, y_desc_vec_, dx_desc_vec_, dy_desc_vec_;
+  #if USE_CUDNN_LSTM_PROJ
+  cudnnRNNDataDescriptor_t x_data_desc_, y_data_desc_, dx_data_desc_, dy_data_desc_;
+  #endif
+  cudnnTensorDescriptor_t hx_desc_, cx_desc_;
+  cudnnTensorDescriptor_t hy_desc_, cy_desc_;
+  cudnnTensorDescriptor_t dhx_desc_, dcx_desc_;
+  cudnnTensorDescriptor_t dhy_desc_, dcy_desc_;
+
+  cudnnFilterDescriptor_t w_desc_, dw_desc_;
+  // Allow TensorCore algo policy
+  bool cudnn_tensor_core_;
+
+  #if CUDNN_MAJOR >= 5
+  cudnnTensorFormat_t format_;
+  #endif
+  #endif
+  bool init_space_, temp_init_space_;
+  size_t reserve_cpu_space_size_, temp_cpu_space_size_;
+  Storage::Handle reserve_cpu_space_, temp_cpu_space_;
+};  //  class RNNOp
+
+static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
+                                 const Context ctx,
+                                 const mxnet::ShapeVector &in_shapes,
+                                 const std::vector<int> &in_types) {
+  const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+  OpStatePtr state = OpStatePtr();
+  MSHADOW_REAL_TYPE_SWITCH(in_types[rnn_enum::kData], DType, {
+    if (ctx.dev_type == kGPU) {
+      state = OpStatePtr::Create<RNNOp<gpu, DType>>(param, ctx);
     } else {
-      out_type->push_back(dtype);
-      // Deal with lstm cell state
-      if (param_.mode == rnn_enum::kLstm)
-        out_type->push_back(dtype);
-      return true;
+      state = OpStatePtr::Create<RNNOp<cpu, DType>>(param, ctx);
     }
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new RNNProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "RNN";
-  }
+  });
+  return state;
+}
 
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
-        in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+template<typename xpu>
+void RNNStatefulCompute(const OpStatePtr& state,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  int dtype = inputs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    RNNOp<xpu, DType>& op = state.get_state<RNNOp<xpu, DType>>();
+    op.Forward(ctx, inputs, req, outputs);
+  });
+}
 
-    if (param_.state_outputs) {
-      dep.push_back(out_data[rnn_enum::kStateOut]);
-      dep.push_back(out_grad[rnn_enum::kStateOut]);
+/*
+index description
+0: x
+1: w
+2: hx
+3: y
+4: dy
+5: hy
+6: dhy
+7: cx
+8: cy
+9: dcy
+*/
+template<typename xpu>
+void RNNStatefulGradCompute(const OpStatePtr& state,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
+  std::vector<TBlob> out_data{inputs[3]};
+  std::vector<TBlob> out_grad{inputs[4]};
+  const std::vector<TBlob> &in_grad = outputs;
+
+  int dtype = inputs[rnn_enum::kData].type_flag_;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    RNNOp<xpu, DType>& op = state.get_state<RNNOp<xpu, DType>>();
+    const RNNParam& param = op.param_;
+    int index = 5;
+    if (param.state_outputs) {
+      out_data.push_back(inputs[index++]);
+      out_grad.push_back(inputs[index++]);
     }
 
-    if (param_.mode == rnn_enum::kLstm) {
-      dep.push_back(in_data[rnn_enum::kStateCell]);
-      if (param_.state_outputs) {
-        dep.push_back(out_data[rnn_enum::kStateCellOut]);
-        dep.push_back(out_grad[rnn_enum::kStateCellOut]);
+    if (param.mode == rnn_enum::kLstm) {
+      in_data.push_back(inputs[index++]);
+      if (param.state_outputs) {
+        out_data.push_back(inputs[index++]);
+        out_grad.push_back(inputs[index]);
       }
     }
-    return dep;
-  }
-
-  std::vector<ResourceRequest> ForwardResource(
-      const mxnet::ShapeVector &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
-
-  std::vector<ResourceRequest> BackwardResource(
-      const mxnet::ShapeVector &in_shape) const override {
-    return {ResourceRequest::kTempSpace};
-  }
 
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented";
-    return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                             std::vector<int> *in_type) const override;
+    op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
+  });
+}
 
- private:
-  RNNParam param_;
-};  // class RNNProp
-#endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
+
 #endif  // MXNET_OPERATOR_RNN_INL_H_
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 621b9eb..74c563a 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -27,24 +27,142 @@
 
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  Operator *op = nullptr;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new RNNOp<DType>(param);
-  });
-  return op;
+
+DMLC_REGISTER_PARAMETER(RNNParam);
+static inline std::vector<std::string> ListArguments(const RNNParam& param_) {
+  if (param_.mode == rnn_enum::kLstm) {
+    return {"data", "parameters", "state", "state_cell"};
+  } else {
+    return {"data", "parameters", "state"};
+  }
 }
 
-Operator *RNNProp::CreateOperatorEx(Context ctx,
-                                  mxnet::ShapeVector *in_shape,
-                                  std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+static bool RNNShape(const nnvm::NodeAttrs& attrs,
+                     std::vector<TShape> *in_shape,
+                     std::vector<TShape> *out_shape) {
+  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
+  using namespace mshadow;
+  if (param_.mode == rnn_enum::kLstm) {
+    CHECK_EQ(in_shape->size(), 4U) << "Needed input:[data, parameters, state, cell_state],"
+        << " got in_shape->size(): " << in_shape->size();
+  } else {
+    CHECK_EQ(in_shape->size(), 3U) <<
+      "Needed input:[data, parameters, state], got in_shape->size(): " << in_shape->size();
+  }
+  const TShape &dshape = (*in_shape)[rnn_enum::kData];
+  if (dshape.ndim() ==  0) return false;
+  CHECK_EQ(dshape.ndim(), 3U) \
+      << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
+  // data: [sequence len, batch, input dimension]
+  int batch_size = dshape[1];
+  int input_size = dshape[2];
+  int numDirections = param_.bidirectional ? 2 : 1;
+  int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+  int layer_size = (param_.projection_size.has_value()) ?
+      param_.projection_size.value() : param_.state_size;
+  SHAPE_ASSIGN_CHECK(*in_shape,
+                     rnn_enum::kState,
+                     Shape3(total_layers, batch_size, layer_size));
+  if (param_.mode == rnn_enum::kLstm) {
+    SHAPE_ASSIGN_CHECK(*in_shape,
+                       rnn_enum::kStateCell,
+                       Shape3(total_layers, batch_size, param_.state_size));
+  }
+
+  // calculate parameter vector length
+  int param_size = GetRnnParamSize(param_.num_layers,
+                                   input_size,
+                                   param_.state_size,
+                                   numDirections,
+                                   param_.mode,
+                                   param_.projection_size);
+  SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
+  out_shape->clear();
+  // output: [sequence len, batch, output size]
+  TShape oshape = dshape;
+  if (param_.projection_size.has_value()) {
+    oshape[2] = numDirections * param_.projection_size.value();
+  } else {
+    oshape[2] = numDirections * param_.state_size;
+  }
+  out_shape->push_back(oshape);
+  if (param_.state_outputs) {
+    // outStateShape: [layer_num, batch, state size]
+    TShape outStateShape = dshape;
+    outStateShape[0] = total_layers;
+    outStateShape[1] = batch_size;
+    if (param_.projection_size.has_value()) {
+      outStateShape[2] = param_.projection_size.value();
+    } else {
+      outStateShape[2] = param_.state_size;
+    }
+    out_shape->push_back(outStateShape);
+    // Deal with lstm cell state
+    if (param_.mode == rnn_enum::kLstm) {
+      TShape cellStateShape = dshape;
+      cellStateShape[0] = total_layers;
+      cellStateShape[1] = batch_size;
+      cellStateShape[2] = param_.state_size;
+      out_shape->push_back(cellStateShape);
+    }
+  }
+  return true;
 }
 
-DMLC_REGISTER_PARAMETER(RNNParam);
+static bool RNNType(const nnvm::NodeAttrs& attrs,
+                    std::vector<int> *in_type,
+                    std::vector<int> *out_type) {
+  const RNNParam& param_ = nnvm::get<RNNParam>(attrs.parsed);
+  if (param_.mode == rnn_enum::kLstm) {
+    CHECK_EQ(in_type->size(), 4U);
+  } else {
+    CHECK_EQ(in_type->size(), 3U);
+  }
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1) {
+      TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+    } else {
+      UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
+    }
+  }
+  out_type->clear();
+  out_type->push_back(dtype);
+  if (param_.state_outputs) {
+    out_type->push_back(dtype);
+    // Deal with lstm cell state
+    if (param_.mode == rnn_enum::kLstm)
+      out_type->push_back(dtype);
+  }
+  return true;
+}
 
-MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
+struct RNNGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr &n,
+          const std::vector<nnvm::NodeEntry> &ograd) const {
+    const RNNParam& params = nnvm::get<RNNParam>(n->attrs.parsed);
+    std::vector<nnvm::NodeEntry> heads{ n->inputs[rnn_enum::kData],
+      n->inputs[rnn_enum::kParams], n->inputs[rnn_enum::kState] };
+    heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kOut, 0});
+    heads.push_back(ograd[rnn_enum::kOut]);
+    if (params.state_outputs) {
+      heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateOut, 0});
+      heads.push_back(ograd[rnn_enum::kStateOut]);
+    }
+    if (params.mode == rnn_enum::kLstm) {
+      heads.push_back(n->inputs[rnn_enum::kStateCell]);
+      if (params.state_outputs) {
+        heads.emplace_back(nnvm::NodeEntry{n, rnn_enum::kStateCellOut, 0});
+        heads.push_back(ograd[rnn_enum::kStateCellOut]);
+      }
+    }
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
+
+NNVM_REGISTER_OP(RNN)
 .describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
 
@@ -97,7 +215,49 @@ The definition of GRU here is slightly different from paper but compatible with
             z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
             n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
             h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
-            \end{array})code")
+            \end{array}
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<RNNParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  //  kOut
+  int num_outputs = 1;
+  if (params.state_outputs) {
+    // kOut, kStateOut, kStateCellOut
+    num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2;
+  }
+
+  return num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return ListArguments(params);
+})
+.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
+.set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
+.set_attr<FResourceRequestEx>("FResourceRequestEx",
+  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
+    std::vector<ResourceRequest> request;
+    const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+    if (param.p == 0) return request;
+    if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN_RNN
+      if (1.0f - param.p > 0) {
+        request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
+        return request;
+      }
+#endif
+    }
+    return request;
+})
 .add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
 .add_argument("parameters", "NDArray-or-Symbol",
               "Vector of all RNN trainable parameters concatenated")
@@ -105,5 +265,15 @@ The definition of GRU here is slightly different from paper but compatible with
 .add_argument("state_cell", "NDArray-or-Symbol",
               "initial cell state for LSTM networks (only for LSTM)")
 .add_arguments(RNNParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_RNN)
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_attr_parser(ParamParser<RNNParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>);
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/rnn.cu b/src/operator/rnn.cu
index 402a8cf..77bb955 100644
--- a/src/operator/rnn.cu
+++ b/src/operator/rnn.cu
@@ -26,24 +26,14 @@
 
 #include "./rnn-inl.h"
 #include <algorithm>
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-#include "./cudnn_rnn-inl.h"
-#endif  // MXNET_USE_CUDNN && CUDNN_MAJOR
 
 namespace mxnet {
 namespace op {
-template<>
-Operator* CreateOp<gpu>(RNNParam param, int dtype) {
-  Operator *op = NULL;
-#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new CuDNNRNNOp<DType>(param);
-  })
-#else
-  LOG(FATAL) << "RNN on GPU is only available for cuDNN at the moment.";
-#endif  // MXNET_USE_CUDNN && CUDNN_MAJOR
-  return op;
-}
 
+NNVM_REGISTER_OP(RNN)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", RNNStatefulCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_RNN)
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", RNNStatefulGradCompute<gpu>);
 }  // namespace op
 }  // namespace mxnet

Reply via email to