lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r272925800
##########
File path: src/operator/rnn-inl.h
##########
@@ -438,385 +565,893 @@ class RNNOp : public Operator{
}
DType* cx_ptr = NULL;
DType* cy_ptr = NULL;
-
- if (param_.mode == rnn_enum::kLstm) {
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
- }
+ if (param_.mode == rnn_enum::kLstm)
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+ cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
+
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
+
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ if (!init_cudnn_) {
+ Init(ctx, s, in_data, out_data);
}
-
- // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                       param_.state_size, direction, param_.mode);
- Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
- .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<xpu, 1, DType> temp_space =
+ ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+ mshadow::Shape1(temp_size), s);
+
+ #if USE_CUDNN_LSTM_PROJ
+ std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+ dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+ out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+ dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
- if (init_space_ && reserve_space_size_ < r_size) {
- Storage::Get()->Free(reserve_space_);
- init_space_ = false;
- }
-
- if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
- reserve_space_size_ = r_size;
- init_space_ = true;
- }
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+ dtype_,
+                                             CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+ dtype_,
+                                             CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
+ }
+ #endif
+
+ #if USE_CUDNN_LSTM_PROJ
+ bool clip_state = param_.lstm_state_clip_min.has_value();
+ bool clip_nan = param_.lstm_state_clip_nan;
+ CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+ rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+ #endif
- DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
- RNNForwardTraining<DType>(workspace.dptr_,
- reserve_space_ptr,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.p,
- param_.mode);
+ if (ctx.is_train) {
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #endif
} else {
- RNNForwardInference<DType>(workspace.dptr_,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.mode);
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #endif
+ }
+ #endif
+
+ if (ctx_.dev_type == kCPU) {
+ // allocate temp space
+ const size_t work_cpu_space_size =
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
+ Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+ .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+ DType* work_cpu_space = workspace.dptr_;
+ if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+ if (init_space_ && reserve_cpu_space_size_ < r_size) {
+ Storage::Get()->Free(reserve_cpu_space_);
+ init_space_ = false;
+ }
+ if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+ reserve_cpu_space_size_ = r_size;
+ init_space_ = true;
+ }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+ RNNForwardTraining<DType>(work_cpu_space,
+ reserve_space_ptr,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.p,
+ param_.mode);
+ } else {
+ RNNForwardInference<DType>(work_cpu_space,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.mode);
+ }
}
}
- virtual void Backward(const OpContext &ctx,
- const std::vector<TBlob> &out_grad,
- const std::vector<TBlob> &in_data,
- const std::vector<TBlob> &out_data,
- const std::vector<OpReqType> &req,
- const std::vector<TBlob> &in_grad,
- const std::vector<TBlob> &aux_args) {
+ void Backward(const OpContext &ctx,
+ const std::vector<TBlob> &out_grad,
+ const std::vector<TBlob> &in_data,
+ const std::vector<TBlob> &out_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &in_grad) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
- size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
- if (!param_.state_outputs) {
- out_expected = 1;
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ // kOut
+ size_t num_outputs = 1;
+ if (param_.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
- CHECK_EQ(in_data.size(), in_expected);
- CHECK_EQ(out_data.size(), out_expected);
- CHECK_EQ(in_grad.size(), in_expected);
- CHECK_EQ(out_grad.size(), out_expected);
- CHECK_EQ(req.size(), in_expected);
+
+ CHECK_EQ(in_data.size(), num_inputs);
+ CHECK_EQ(out_data.size(), num_outputs);
+ CHECK_EQ(in_grad.size(), num_inputs);
+ CHECK_EQ(out_grad.size(), num_outputs);
+ CHECK_EQ(req.size(), num_inputs);
     CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
     CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
- mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ Stream<xpu> *s = ctx.get_stream<xpu>();
// get input + output tensors
- Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
- Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
- Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
- CHECK(x.CheckContiguous());
- CHECK(w.CheckContiguous());
- CHECK(hx.CheckContiguous());
- CHECK(y.CheckContiguous());
- CHECK(dx.CheckContiguous());
- CHECK(dw.CheckContiguous());
- CHECK(dhx.CheckContiguous());
- CHECK(dy.CheckContiguous());
+ Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(dw.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(dhx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
+ CHECK_EQ(dy.CheckContiguous(), true);
+ CHECK_EQ(dx.CheckContiguous(), true);
+
+ if (req[rnn_enum::kParams] != kAddTo) {
+ dw = mshadow::expr::ScalarExp<DType>(0.0f);
+ }
+
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
const int direction = param_.bidirectional ? 2 : 1;
     const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
DType * dhy_ptr = NULL;
if (param_.state_outputs) {
dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
}
- DType * cx_ptr = NULL;
- DType * dcx_ptr = NULL;
- DType * dcy_ptr = NULL;
+ DType* dcx_ptr = NULL;
+ DType* dcy_ptr = NULL;
+ DType* cx_ptr = NULL;
if (param_.mode == rnn_enum::kLstm) {
       CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
- }
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
}
+ if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+      dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
- // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
-                                                      param_.state_size, direction, param_.mode);
- Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
- .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
- size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                           param_.seq_length_, param_.batch_size_,
- param_.state_size, param_.mode);
- if (!init_space_ || reserve_space_size_ != r_size) {
- LOG(FATAL) << "Check forward init error";
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ if (!init_cudnn_) {
+ Init(ctx, s, in_data, out_data);
}
- DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
- RNNBackward<DType>(workspace.dptr_,
- reserve_space_ptr,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- y.dptr_,
- dy.dptr_,
- dhy_ptr,
- dcy_ptr,
- dx.dptr_,
- dhx.dptr_,
- dcx_ptr,
- dw.dptr_,
- db_ptr,
- req[rnn_enum::kData],
- req[rnn_enum::kParams],
- req[rnn_enum::kState],
-                       // State cell should be present for LSTMs, but is absent for other RNNs.
-                       param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
- param_.p,
- param_.mode);
- }
-
- private:
- RNNParam param_;
- bool init_space_;
- size_t reserve_space_size_;
- Storage::Handle reserve_space_;
-}; // class RNNOp
-
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<xpu, 1, DType> temp_space =
+ ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+ mshadow::Shape1(temp_size), s);
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+ rnn_desc_,
+ y_data_desc_,
+ y.dptr_,
+ dy_data_desc_,
+ dy.dptr_,
+ nullptr,
+ nullptr,
+ dhy_desc_,
+ dhy_ptr,
+ dcy_desc_,
+ dcy_ptr,
+ w_desc_,
+ w.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ dx_data_desc_,
+ dx.dptr_,
+ dhx_desc_,
+ dhx.dptr_,
+ dcx_desc_,
+ dcx_ptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ temp_space.dptr_,
+ workspace_byte_,
+ dw_desc_,
+ dw.dptr_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ dy_desc_vec_.data(),
+ dy.dptr_,
+ dhy_desc_,
+ dhy_ptr,
+ dcy_desc_,
+ dcy_ptr,
+ w_desc_,
+ w.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ dx_desc_vec_.data(),
+ dx.dptr_,
+ dhx_desc_,
+ dhx.dptr_,
+ dcx_desc_,
+ dcx_ptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ temp_space.dptr_,
+ workspace_byte_,
+ dw_desc_,
+ dw.dptr_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #endif
+ #endif
+
+ if (ctx_.dev_type == kCPU) {
+ // allocate temp space
+ const size_t work_cpu_space_size =
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
+ DType* work_cpu_space = NULL;
+ Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+ .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+ work_cpu_space = workspace.dptr_;
+ size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                             param_.seq_length_, param_.batch_size_,
+ param_.state_size, param_.mode);
+
+ if (!init_space_ || reserve_cpu_space_size_ != r_size) {
+ LOG(FATAL) << "Check forward init error";
+ }
-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
- std::vector<std::string> ListArguments() const override {
- if (param_.mode == rnn_enum::kLstm) {
- return {"data", "parameters", "state", "state_cell"};
- } else {
- return {"data", "parameters", "state"};
+ DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+ RNNBackward<DType>(work_cpu_space,
+ reserve_space_ptr,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ y.dptr_,
+ dy.dptr_,
+ dhy_ptr,
+ dcy_ptr,
+ dx.dptr_,
+ dhx.dptr_,
+ dcx_ptr,
+ dw.dptr_,
+ db_ptr,
+ req[rnn_enum::kData],
+ req[rnn_enum::kParams],
+ req[rnn_enum::kState],
+                         // State cell should be present for LSTMs, but is absent for other RNNs.
+                         param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
+ param_.p,
+ param_.mode);
}
}
- std::vector<std::string> ListOutputs() const override {
- std::vector<std::string> outputs = {"output"};
- if (!param_.state_outputs)
- return outputs;
- else
- outputs.emplace_back("state");
- if (param_.mode == rnn_enum::kLstm)
- outputs.emplace_back("state_cell");
- return outputs;
- }
-
- int NumOutputs() const override {
- int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
- int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
- return num_outputs;
- }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
- param_.Init(kwargs);
- }
-
- std::map<std::string, std::string> GetParams() const override {
- return param_.__DICT__();
- }
- bool InferShape(mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape,
- mxnet::ShapeVector *aux_shape) const override {
+ private:
+ inline void Init(const OpContext &ctx,
+ mshadow::Stream<xpu> *s,
+ const std::vector<TBlob> &in_data,
+ const std::vector<TBlob> &out_data) {
using namespace mshadow;
- if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
- } else {
- CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
- }
- const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
- if (dshape.ndim() == 0) return false;
- CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
- // data: [sequence len, batch, input dimension]
- int batch_size = dshape[1];
- int input_size = dshape[2];
- int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
- int layer_size = (param_.projection_size.has_value()) ?
- param_.projection_size.value() : param_.state_size;
- SHAPE_ASSIGN_CHECK(*in_shape,
- rnn_enum::kState,
- Shape3(total_layers, batch_size, layer_size));
- if (param_.mode == rnn_enum::kLstm)
- SHAPE_ASSIGN_CHECK(*in_shape,
- rnn_enum::kStateCell,
- Shape3(total_layers, batch_size, param_.state_size));
-
- // calculate parameter vector length
- int param_size = GetRnnParamSize(param_.num_layers,
- input_size,
- param_.state_size,
- numDirections,
- param_.mode,
- param_.projection_size);
- SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
- out_shape->clear();
- // output: [sequence len, batch, output size]
- mxnet::TShape oshape = dshape;
- if (param_.projection_size.has_value()) {
- oshape[2] = numDirections * param_.projection_size.value();
- } else {
- oshape[2] = numDirections * param_.state_size;
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ // kOut
+ size_t num_outputs = 1;
+ if (param_.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
- out_shape->push_back(oshape);
- if (!param_.state_outputs) {
- return true;
- } else {
- // outStateShape: [layer_num, batch, state size]
- mxnet::TShape outStateShape = dshape;
- outStateShape[0] = total_layers;
- outStateShape[1] = batch_size;
- if (param_.projection_size.has_value()) {
- outStateShape[2] = param_.projection_size.value();
- } else {
- outStateShape[2] = param_.state_size;
+
+ CHECK_EQ(in_data.size(), num_inputs);
+ CHECK_EQ(out_data.size(), num_outputs);
+
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ #if CUDNN_MAJOR >= 5
+ format_ = CUDNN_TENSOR_NCHW;
+ #endif
+
+ if (!init_cudnn_) {
+ init_cudnn_ = true;
+ // get input + output tensors
+ Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+      Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+ param_.seq_length_ = x.shape_[0];
+ param_.batch_size_ = x.shape_[1];
+ param_.input_size_ = x.shape_[2];
+
+ // Tensor Descriptors
+ std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+ int dimA[3];
+ int strideA[3];
+ for (int i = 0; i < param_.seq_length_; i++) {
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+ dimA[0] = param_.batch_size_;
+ dimA[1] = param_.input_size_;
+ dimA[2] = 1;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ dimA[0] = param_.batch_size_;
+        dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
+ dimA[2] = 1;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
}
- out_shape->push_back(outStateShape);
- // Deal with lstm cell state
- if (param_.mode == rnn_enum::kLstm) {
- mxnet::TShape cellStateShape = dshape;
- cellStateShape[0] = total_layers;
- cellStateShape[1] = batch_size;
- cellStateShape[2] = param_.state_size;
- out_shape->push_back(cellStateShape);
+ x_desc_vec_ = x_vec;
+ y_desc_vec_ = y_vec;
+ dx_desc_vec_ = dx_vec;
+ dy_desc_vec_ = dy_vec;
+
+ // set the state tensors
+ dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+ dimA[1] = param_.batch_size_;
+ dimA[2] = param_.state_size;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+ #if USE_CUDNN_LSTM_PROJ
+ int dimB[3];
+ int strideB[3];
+ dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+ dimB[1] = param_.batch_size_;
+ dimB[2] = param_.projection_size.has_value() ?
+ param_.projection_size.value() : param_.state_size;
+ strideB[0] = dimB[2] * dimB[1];
+ strideB[1] = dimB[2];
+ strideB[2] = 1;
+ #endif
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+
+ // Create Dropout descriptors
+ DType* dropout_states_ = NULL;
+ if (param_.p > 0) {
+ CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
+ dropout_size_ = dropout_byte_ / sizeof(DType);
+        dropout_states_ = ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
Review comment:
fixed
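
    For readers skimming the diff, here is a minimal standalone sketch of the pattern it
    moves toward: lazy, cached reserve-space allocation plus per-device dispatch inside a
    stateful operator. The names below (StatefulRnnSketch, Device, the size formula) are
    illustrative only and are not MXNet identifiers.

    // Illustrative sketch only; mirrors the structure of the diff above, not its code.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    enum class Device { kCPU, kGPU };

    class StatefulRnnSketch {
     public:
      explicit StatefulRnnSketch(Device dev) : dev_(dev) {}

      void Forward(std::size_t seq_len, std::size_t batch, std::size_t state, bool is_train) {
        if (dev_ == Device::kGPU) {
          // A real GPU path would call the cuDNN RNN forward routines here,
          // reusing descriptors built once by a lazy Init() step.
          std::cout << "GPU (cuDNN) path would run here\n";
          return;
        }
        // CPU path: size the reserve space once and keep it alive for Backward().
        const std::size_t needed = seq_len * batch * state;  // stand-in for the real size formula
        if (is_train && reserve_.size() < needed) {
          reserve_.resize(needed);  // analogous to Storage::Get()->Alloc(...)
        }
        std::cout << "CPU path, cached reserve elements: " << reserve_.size() << "\n";
      }

     private:
      Device dev_;
      std::vector<float> reserve_;  // analogous to reserve_cpu_space_
    };

    int main() {
      StatefulRnnSketch op(Device::kCPU);
      op.Forward(/*seq_len=*/5, /*batch=*/2, /*state=*/8, /*is_train=*/true);
      op.Forward(5, 2, 8, /*is_train=*/false);  // reuses the cached reserve space
    }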
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services