lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r269112331
##########
File path: src/operator/rnn-inl.h
##########
@@ -436,387 +566,897 @@ class RNNOp : public Operator{
if (param_.state_outputs) {
hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
}
- DType* cx_ptr = NULL;
- DType* cy_ptr = NULL;
+ DType * cx_ptr = NULL;
+ DType * cy_ptr = NULL;
+ if (param_.mode == rnn_enum::kLstm)
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+ cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
- if (param_.mode == rnn_enum::kLstm) {
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
- }
- }
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
// allocate temp space
- const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
- Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
- .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+ DType* work_cpu_space = NULL;
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ if (!init_cudnn_) {
+ Init(s, in_data, out_data);
+ }
+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<xpu, 1, DType> temp_space =
+ ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+ mshadow::Shape1(temp_size + work_cpu_space_size), s);
+
+ work_cpu_space = temp_space.dptr_ + temp_size;
+
+ #if USE_CUDNN_LSTM_PROJ
+ std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+ dtype_,
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ int out_size =
+ (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+ out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+ dtype_,
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
+ if (ctx.is_train) {
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+ dtype_,
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+ dtype_,
+ CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
+ }
+ #endif
+
+ #if USE_CUDNN_LSTM_PROJ
+ bool clip_state = param_.lstm_state_clip_min.has_value();
+ bool clip_nan = param_.lstm_state_clip_nan;
+ CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+ rnn_desc_,
+ clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+ clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+ clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+ clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+ #endif
if (ctx.is_train) {
- const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
- param_.seq_length_, param_.batch_size_,
- param_.state_size, param_.mode);
- if (init_space_ && reserve_space_size_ < r_size) {
- Storage::Get()->Free(reserve_space_);
- init_space_ = false;
- }
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #endif
+ } else {
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #endif
+ }
+ #endif
- if (!init_space_) {
- reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
- reserve_space_size_ = r_size;
- init_space_ = true;
+ if (ctx_.dev_type == kCPU) {
+ if (!work_cpu_space) {
+ Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+ .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+ work_cpu_space = workspace.dptr_;
+ }
+ if (ctx.is_train) {
+ const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+ param_.seq_length_, param_.batch_size_,
+ param_.state_size, param_.mode);
+ if (init_space_ && reserve_cpu_space_size_ < r_size) {
+ Storage::Get()->Free(reserve_cpu_space_);
+ init_space_ = false;
+ }
+ if (!init_space_) {
+ reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+ reserve_cpu_space_size_ = r_size;
+ init_space_ = true;
+ }
+
+ DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+ RNNForwardTraining<DType>(work_cpu_space,
+ reserve_space_ptr,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.p,
+ param_.mode);
+ } else {
+ RNNForwardInference<DType>(work_cpu_space,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.mode);
}
-
- DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
- RNNForwardTraining<DType>(workspace.dptr_,
- reserve_space_ptr,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.p,
- param_.mode);
- } else {
- RNNForwardInference<DType>(workspace.dptr_,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.mode);
}
}
- virtual void Backward(const OpContext &ctx,
- const std::vector<TBlob> &out_grad,
- const std::vector<TBlob> &in_data,
- const std::vector<TBlob> &out_data,
- const std::vector<OpReqType> &req,
- const std::vector<TBlob> &in_grad,
- const std::vector<TBlob> &aux_args) {
+ void Backward(const OpContext &ctx,
+ const std::vector<TBlob> &out_grad,
+ const std::vector<TBlob> &in_data,
+ const std::vector<TBlob> &out_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &in_grad) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
- size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
- if (!param_.state_outputs) {
- out_expected = 1;
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ // kOut
+ size_t num_outputs = 1;
+ if (param_.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
- CHECK_EQ(in_data.size(), in_expected);
- CHECK_EQ(out_data.size(), out_expected);
- CHECK_EQ(in_grad.size(), in_expected);
- CHECK_EQ(out_grad.size(), out_expected);
- CHECK_EQ(req.size(), in_expected);
+
+ CHECK_EQ(in_data.size(), num_inputs);
+ CHECK_EQ(out_data.size(), num_outputs);
+ CHECK_EQ(in_grad.size(), num_inputs);
+ CHECK_EQ(out_grad.size(), num_outputs);
+ CHECK_EQ(req.size(), num_inputs);
CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
- mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ Stream<xpu> *s = ctx.get_stream<xpu>();
// get input + output tensors
- Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
- Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
- Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
- Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
- Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
- CHECK(x.CheckContiguous());
- CHECK(w.CheckContiguous());
- CHECK(hx.CheckContiguous());
- CHECK(y.CheckContiguous());
- CHECK(dx.CheckContiguous());
- CHECK(dw.CheckContiguous());
- CHECK(dhx.CheckContiguous());
- CHECK(dy.CheckContiguous());
+ Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+ Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+ Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(dw.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(dhx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
+ CHECK_EQ(dy.CheckContiguous(), true);
+
+ if (req[rnn_enum::kParams] != kAddTo) {
+ dw = mshadow::expr::ScalarExp<DType>(0.0f);
+ }
+
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
const int direction = param_.bidirectional ? 2 : 1;
const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size, direction, param_.mode);
+
DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
DType * dhy_ptr = NULL;
if (param_.state_outputs) {
dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
}
- DType * cx_ptr = NULL;
- DType * dcx_ptr = NULL;
- DType * dcy_ptr = NULL;
+ DType* dcx_ptr = NULL;
+ DType* dcy_ptr = NULL;
+ DType* cx_ptr = NULL;
if (param_.mode == rnn_enum::kLstm) {
CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
- }
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
}
+ if ((param_.mode == rnn_enum::kLstm) && param_.state_outputs)
+ dcy_ptr = (out_grad[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
// allocate temp space
- const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
- param_.state_size, direction, param_.mode);
- Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
- .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
-
- size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
- param_.seq_length_, param_.batch_size_,
- param_.state_size, param_.mode);
- if (!init_space_ || reserve_space_size_ != r_size) {
- LOG(FATAL) << "Check forward init error";
+ const size_t work_cpu_space_size =
+ GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+ param_.state_size, direction, param_.mode);
+ DType* work_cpu_space = NULL;
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ if (!init_cudnn_) {
+ Init(s, in_data, out_data);
}
- DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
- RNNBackward<DType>(workspace.dptr_,
- reserve_space_ptr,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- y.dptr_,
- dy.dptr_,
- dhy_ptr,
- dcy_ptr,
- dx.dptr_,
- dhx.dptr_,
- dcx_ptr,
- dw.dptr_,
- db_ptr,
- req[rnn_enum::kData],
- req[rnn_enum::kParams],
- req[rnn_enum::kState],
- // State cell should be present for LSTMs, but is absent for other RNNs.
- param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
- param_.p,
- param_.mode);
- }
-
- private:
- RNNParam param_;
- bool init_space_;
- size_t reserve_space_size_;
- Storage::Handle reserve_space_;
-}; // class RNNOp
+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<xpu, 1, DType> temp_space =
+ ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+ mshadow::Shape1(temp_size + work_cpu_space_size), s);
+ work_cpu_space = temp_space.dptr_ + temp_size;
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNBackwardDataEx(s->dnn_handle_,
+ rnn_desc_,
+ y_data_desc_,
+ y.dptr_,
+ dy_data_desc_,
+ dy.dptr_,
+ nullptr,
+ nullptr,
+ dhy_desc_,
+ dhy_ptr,
+ dcy_desc_,
+ dcy_ptr,
+ w_desc_,
+ w.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ dx_data_desc_,
+ dx.dptr_,
+ dhx_desc_,
+ dhx.dptr_,
+ dcx_desc_,
+ dcx_ptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ CUDNN_CALL(cudnnRNNBackwardWeightsEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ temp_space.dptr_,
+ workspace_byte_,
+ dw_desc_,
+ dw.dptr_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ dy_desc_vec_.data(),
+ dy.dptr_,
+ dhy_desc_,
+ dhy_ptr,
+ dcy_desc_,
+ dcy_ptr,
+ w_desc_,
+ w.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ dx_desc_vec_.data(),
+ dx.dptr_,
+ dhx_desc_,
+ dhx.dptr_,
+ dcx_desc_,
+ dcx_ptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ temp_space.dptr_,
+ workspace_byte_,
+ dw_desc_,
+ dw.dptr_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #endif
+ #endif
+
+ if (ctx_.dev_type == kCPU) {
+ if (!work_cpu_space) {
+ Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+ .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+ work_cpu_space = workspace.dptr_;
+ }
+ size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+ param_.seq_length_, param_.batch_size_,
+ param_.state_size, param_.mode);
-template<typename xpu>
-Operator* CreateOp(RNNParam param, int dtype);
+ if (!init_space_ || reserve_cpu_space_size_ != r_size) {
+ LOG(FATAL) << "Check forward init error";
+ }
-#if DMLC_USE_CXX11
-class RNNProp : public OperatorProperty {
- public:
- std::vector<std::string> ListArguments() const override {
- if (param_.mode == rnn_enum::kLstm) {
- return {"data", "parameters", "state", "state_cell"};
- } else {
- return {"data", "parameters", "state"};
+ DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+ RNNBackward<DType>(work_cpu_space,
+ reserve_space_ptr,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ y.dptr_,
+ dy.dptr_,
+ dhy_ptr,
+ dcy_ptr,
+ dx.dptr_,
+ dhx.dptr_,
+ dcx_ptr,
+ dw.dptr_,
+ db_ptr,
+ req[rnn_enum::kData],
+ req[rnn_enum::kParams],
+ req[rnn_enum::kState],
+ // State cell should be present for LSTMs, but is absent for other RNNs.
+ param_.mode == rnn_enum::kLstm ? req[rnn_enum::kStateCell] : kNullOp,
+ param_.p,
+ param_.mode);
}
}
- std::vector<std::string> ListOutputs() const override {
- std::vector<std::string> outputs = {"output"};
- if (!param_.state_outputs)
- return outputs;
- else
- outputs.emplace_back("state");
- if (param_.mode == rnn_enum::kLstm)
- outputs.emplace_back("state_cell");
- return outputs;
- }
-
- int NumOutputs() const override {
- int mode_num = (param_.mode == rnn_enum::kLstm) ? 2 : 1;
- int num_outputs = param_.state_outputs ? (mode_num + 1) : 1;
- return num_outputs;
- }
-
- void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
- param_.Init(kwargs);
- }
- std::map<std::string, std::string> GetParams() const override {
- return param_.__DICT__();
- }
-
- bool InferShape(mxnet::ShapeVector *in_shape,
- mxnet::ShapeVector *out_shape,
- mxnet::ShapeVector *aux_shape) const override {
+ private:
+ inline void Init(mshadow::Stream<xpu> *s,
+ const std::vector<TBlob> &in_data,
+ const std::vector<TBlob> &out_data) {
using namespace mshadow;
- if (param_.mode == rnn_enum::kLstm) {
- CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
- } else {
- CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
- }
- const mxnet::TShape &dshape = (*in_shape)[rnn_enum::kData];
- if (dshape.ndim() == 0) return false;
- CHECK_EQ(dshape.ndim(), 3U) \
- << "Input data should be rank-3 tensor of dim [sequence length, batch
size, input size]";
- // data: [sequence len, batch, input dimension]
- int batch_size = dshape[1];
- int input_size = dshape[2];
- int numDirections = param_.bidirectional ? 2 : 1;
- int total_layers = numDirections * param_.num_layers; // double for bidirectional
- int layer_size = (param_.projection_size.has_value()) ?
- param_.projection_size.value() : param_.state_size;
- SHAPE_ASSIGN_CHECK(*in_shape,
- rnn_enum::kState,
- Shape3(total_layers, batch_size, layer_size));
- if (param_.mode == rnn_enum::kLstm)
- SHAPE_ASSIGN_CHECK(*in_shape,
- rnn_enum::kStateCell,
- Shape3(total_layers, batch_size, param_.state_size));
-
- // calculate parameter vector length
- int param_size = GetRnnParamSize(param_.num_layers,
- input_size,
- param_.state_size,
- numDirections,
- param_.mode,
- param_.projection_size);
- SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size));
-
- out_shape->clear();
- // output: [sequence len, batch, output size]
- mxnet::TShape oshape = dshape;
- if (param_.projection_size.has_value()) {
- oshape[2] = numDirections * param_.projection_size.value();
- } else {
- oshape[2] = numDirections * param_.state_size;
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ // kOut
+ size_t num_outputs = 1;
+ if (param_.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
- out_shape->push_back(oshape);
- if (!param_.state_outputs) {
- return true;
- } else {
- // outStateShape: [layer_num, batch, state size]
- mxnet::TShape outStateShape = dshape;
- outStateShape[0] = total_layers;
- outStateShape[1] = batch_size;
- if (param_.projection_size.has_value()) {
- outStateShape[2] = param_.projection_size.value();
+
+ CHECK_EQ(in_data.size(), num_inputs);
+ CHECK_EQ(out_data.size(), num_outputs);
+
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ #if CUDNN_MAJOR >= 5
+ format_ = CUDNN_TENSOR_NCHW;
+ #endif
+
+ if (!init_cudnn_) {
+ init_cudnn_ = true;
+ // get input + output tensors
+ Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+ param_.seq_length_ = x.shape_[0];
+ param_.batch_size_ = x.shape_[1];
+ param_.input_size_ = x.shape_[2];
+
+ // Tensor Descriptors
+ std::vector<cudnnTensorDescriptor_t> x_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> y_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> dx_vec(param_.seq_length_);
+ std::vector<cudnnTensorDescriptor_t> dy_vec(param_.seq_length_);
+ int dimA[3];
+ int strideA[3];
+ for (int i = 0; i < param_.seq_length_; i++) {
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i]));
+ CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i]));
+
+ dimA[0] = param_.batch_size_;
+ dimA[1] = param_.input_size_;
+ dimA[2] = 1;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ dimA[0] = param_.batch_size_;
+ dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size;
+ dimA[2] = 1;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i],
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ }
+ x_desc_vec_ = x_vec;
+ y_desc_vec_ = y_vec;
+ dx_desc_vec_ = dx_vec;
+ dy_desc_vec_ = dy_vec;
+
+ // set the state tensors
+ dimA[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+ dimA[1] = param_.batch_size_;
+ dimA[2] = param_.state_size;
+ strideA[0] = dimA[2] * dimA[1];
+ strideA[1] = dimA[2];
+ strideA[2] = 1;
+ #if USE_CUDNN_LSTM_PROJ
+ int dimB[3];
+ int strideB[3];
+ dimB[0] = param_.num_layers * (param_.bidirectional ? 2 : 1);
+ dimB[1] = param_.batch_size_;
+ dimB[2] = param_.projection_size.has_value() ?
+ param_.projection_size.value() : param_.state_size;
+ strideB[0] = dimB[2] * dimB[1];
+ strideB[1] = dimB[2];
+ strideB[2] = 1;
+ #endif
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+ dtype_,
+ 3,
+ dimB,
+ strideB));
+ #else
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+ #endif
+ CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_,
+ dtype_,
+ 3,
+ dimA,
+ strideA));
+
+ // Create Dropout descriptors
+ if (param_.p > 0) {
+ CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
+ dropout_size_ = dropout_byte_ / sizeof(DType);
+ dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
Review comment:
Fixing it by using ctx.requested (the temp space resource) to allocate the dropout states instead of calling Storage::Get()->Alloc directly.
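
For illustration, a minimal sketch (not the PR's final code) of what that change looks like: the buffer is requested through ctx.requested[rnn_enum::kTempSpace] with the same get_space_typed call already used for the workspace in the diff above, rather than through a Storage handle that must be freed manually. The helper name AllocDropoutStates and its signature are hypothetical; only the resource request itself mirrors the existing code.

    // Sketch only, assuming the usual MXNet operator headers are in scope.
    // `rnn_enum::kTempSpace`, `OpContext`, and `get_space_typed` are used
    // exactly as in the diff above; `AllocDropoutStates` is a made-up helper.
    template<typename xpu, typename DType>
    DType* AllocDropoutStates(const OpContext &ctx,
                              mshadow::Stream<xpu> *s,
                              size_t dropout_byte) {
      // Round the byte count up to a whole number of DType elements.
      const size_t nelem = (dropout_byte + sizeof(DType) - 1) / sizeof(DType);
      // Request the buffer from the engine-managed temp space instead of
      // calling Storage::Get()->Alloc and keeping a handle to free later.
      mshadow::Tensor<xpu, 1, DType> space =
          ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
              mshadow::Shape1(nelem), s);
      return space.dptr_;
    }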
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services