lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r272866427
##########
File path: src/operator/rnn-inl.h
##########
@@ -436,387 +563,901 @@ class RNNOp : public Operator{
if (param_.state_outputs) {
hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
}
- DType* cx_ptr = NULL;
- DType* cy_ptr = NULL;
+ DType * cx_ptr = NULL;
+ DType * cy_ptr = NULL;
+ if (param_.mode == rnn_enum::kLstm)
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+ cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
- if (param_.mode == rnn_enum::kLstm) {
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
- }
- }
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
// allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
                                                            param_.state_size,
                                                            direction, param_.mode);
- Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
- .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+ DType* work_cpu_space = NULL;
+ #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+ if (!init_cudnn_) {
+ Init(ctx, s, in_data, out_data);
+ }
+ // Get temp space
+ int temp_size = workspace_size_;
+ Tensor<xpu, 1, DType> temp_space =
+ ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+            mshadow::Shape1(temp_size + work_cpu_space_size), s);
+
+ work_cpu_space = temp_space.dptr_ + temp_size;
+
+ #if USE_CUDNN_LSTM_PROJ
+ std::vector<int> seqLengthArray(param_.batch_size_, param_.seq_length_);
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(x_data_desc_,
+ dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ int out_size =
+        (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size;
+ out_size = (param_.bidirectional) ? (out_size * 2) : out_size;
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(y_data_desc_,
+ dtype_,
+                                         CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
+ if (ctx.is_train) {
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dx_data_desc_,
+ dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ seqLengthArray.data(),
+ nullptr));
+ CUDNN_CALL(cudnnSetRNNDataDescriptor(dy_data_desc_,
+ dtype_,
+                                           CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_PACKED,
+ param_.seq_length_,
+ param_.batch_size_,
+ out_size,
+ seqLengthArray.data(),
+ nullptr));
+ }
+ #endif
+
+ #if USE_CUDNN_LSTM_PROJ
+ bool clip_state = param_.lstm_state_clip_min.has_value();
+ bool clip_nan = param_.lstm_state_clip_nan;
+ CUDNN_CALL(cudnnRNNSetClip(s->dnn_handle_,
+ rnn_desc_,
+                               clip_state ? CUDNN_RNN_CLIP_MINMAX : CUDNN_RNN_CLIP_NONE,
+                               clip_nan ? CUDNN_NOT_PROPAGATE_NAN : CUDNN_PROPAGATE_NAN,
+                               clip_state ? param_.lstm_state_clip_min.value() : 0.0,
+                               clip_state ? param_.lstm_state_clip_max.value() : 0.0));
+ #endif
if (ctx.is_train) {
-      const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
-                                                   param_.seq_length_, param_.batch_size_,
-                                                   param_.state_size, param_.mode);
- if (init_space_ && reserve_space_size_ < r_size) {
- Storage::Get()->Free(reserve_space_);
- init_space_ = false;
- }
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardTrainingEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_,
+ reserve_space_.dptr,
+ reserve_space_byte_));
+ #endif
+ } else {
+ #if USE_CUDNN_LSTM_PROJ
+ CUDNN_CALL(cudnnRNNForwardInferenceEx(s->dnn_handle_,
+ rnn_desc_,
+ x_data_desc_,
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_data_desc_,
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #else
+ CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_,
+ rnn_desc_,
+ param_.seq_length_,
+ x_desc_vec_.data(),
+ x.dptr_,
+ hx_desc_,
+ hx.dptr_,
+ cx_desc_,
+ cx_ptr,
+ w_desc_,
+ w.dptr_,
+ y_desc_vec_.data(),
+ y.dptr_,
+ hy_desc_,
+ hy_ptr,
+ cy_desc_,
+ cy_ptr,
+ temp_space.dptr_,
+ workspace_byte_));
+ #endif
+ }
+ #endif
- if (!init_space_) {
-        reserve_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
- reserve_space_size_ = r_size;
- init_space_ = true;
+ if (ctx_.dev_type == kCPU) {
+ if (!work_cpu_space) {
+ Tensor<xpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
+ .get_space_typed<xpu, 1, DType>(Shape1(work_cpu_space_size), s);
+ work_cpu_space = workspace.dptr_;
+ }
+ if (ctx.is_train) {
+        const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
+                                                     param_.seq_length_, param_.batch_size_,
+                                                     param_.state_size, param_.mode);
+ if (init_space_ && reserve_cpu_space_size_ < r_size) {
+ Storage::Get()->Free(reserve_cpu_space_);
+ init_space_ = false;
+ }
+ if (!init_space_) {
+          reserve_cpu_space_ = Storage::Get()->Alloc(r_size * sizeof(DType), Context::CPU());
+ reserve_cpu_space_size_ = r_size;
+ init_space_ = true;
+ }
+
+        DType* reserve_space_ptr = static_cast<DType*>(reserve_cpu_space_.dptr);
+
+ RNNForwardTraining<DType>(work_cpu_space,
+ reserve_space_ptr,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.p,
+ param_.mode);
+ } else {
+ RNNForwardInference<DType>(work_cpu_space,
+ param_.state_outputs,
+ param_.num_layers,
+ direction,
+ param_.seq_length_,
+ param_.batch_size_,
+ param_.input_size_,
+ param_.state_size,
+ x.dptr_,
+ hx.dptr_,
+ cx_ptr,
+ w.dptr_,
+ b_ptr,
+ y.dptr_,
+ hy_ptr,
+ cy_ptr,
+ param_.mode);
}
-
- DType* reserve_space_ptr = static_cast<DType*>(reserve_space_.dptr);
- RNNForwardTraining<DType>(workspace.dptr_,
- reserve_space_ptr,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.p,
- param_.mode);
- } else {
- RNNForwardInference<DType>(workspace.dptr_,
- param_.state_outputs,
- param_.num_layers,
- direction,
- param_.seq_length_,
- param_.batch_size_,
- param_.input_size_,
- param_.state_size,
- x.dptr_,
- hx.dptr_,
- cx_ptr,
- w.dptr_,
- b_ptr,
- y.dptr_,
- hy_ptr,
- cy_ptr,
- param_.mode);
}
}
- virtual void Backward(const OpContext &ctx,
- const std::vector<TBlob> &out_grad,
- const std::vector<TBlob> &in_data,
- const std::vector<TBlob> &out_data,
- const std::vector<OpReqType> &req,
- const std::vector<TBlob> &in_grad,
- const std::vector<TBlob> &aux_args) {
+ void Backward(const OpContext &ctx,
+ const std::vector<TBlob> &out_grad,
+ const std::vector<TBlob> &in_data,
+ const std::vector<TBlob> &out_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &in_grad) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK(param_.p >= 0.0f && param_.p < 1.0f)
<< "unsupported dropout value, should be 0 <= dropout < 1";
- size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
- size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
- if (!param_.state_outputs) {
- out_expected = 1;
+ size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
+ // kOut
+ size_t num_outputs = 1;
+ if (param_.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
}
- CHECK_EQ(in_data.size(), in_expected);
- CHECK_EQ(out_data.size(), out_expected);
- CHECK_EQ(in_grad.size(), in_expected);
- CHECK_EQ(out_grad.size(), out_expected);
- CHECK_EQ(req.size(), in_expected);
+
+ CHECK_EQ(in_data.size(), num_inputs);
+ CHECK_EQ(out_data.size(), num_outputs);
+ CHECK_EQ(in_grad.size(), num_inputs);
+ CHECK_EQ(out_grad.size(), num_outputs);
+ CHECK_EQ(req.size(), num_inputs);
    CHECK_NE(req[rnn_enum::kData], kAddTo) << "AddTo is not supported for data";
    CHECK_NE(req[rnn_enum::kState], kAddTo) << "AddTo is not supported for state";
- mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ Stream<xpu> *s = ctx.get_stream<xpu>();
// get input + output tensors
- Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);
- Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
- Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dx = in_grad[rnn_enum::kData].get<cpu, 3, DType>(s);
-    Tensor<cpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<cpu, 3, DType>(s);
- Tensor<cpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<cpu, 3, DType>(s);
- CHECK(x.CheckContiguous());
- CHECK(w.CheckContiguous());
- CHECK(hx.CheckContiguous());
- CHECK(y.CheckContiguous());
- CHECK(dx.CheckContiguous());
- CHECK(dw.CheckContiguous());
- CHECK(dhx.CheckContiguous());
- CHECK(dy.CheckContiguous());
+ Tensor<xpu, 3, DType> x = in_data[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dx = in_grad[rnn_enum::kData].get<xpu, 3, DType>(s);
+ Tensor<xpu, 1, DType> w = in_data[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 1, DType> dw = in_grad[rnn_enum::kParams].get<xpu, 1, DType>(s);
+    Tensor<xpu, 3, DType> hx = in_data[rnn_enum::kState].get<xpu, 3, DType>(s);
+    Tensor<xpu, 3, DType> dhx = in_grad[rnn_enum::kState].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> y = out_data[rnn_enum::kOut].get<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dy = out_grad[rnn_enum::kOut].get<xpu, 3, DType>(s);
+
+ CHECK_EQ(x.CheckContiguous(), true);
+ CHECK_EQ(w.CheckContiguous(), true);
+ CHECK_EQ(dw.CheckContiguous(), true);
+ CHECK_EQ(hx.CheckContiguous(), true);
+ CHECK_EQ(dhx.CheckContiguous(), true);
+ CHECK_EQ(y.CheckContiguous(), true);
+ CHECK_EQ(dy.CheckContiguous(), true);
+
+ if (req[rnn_enum::kParams] != kAddTo) {
+ dw = mshadow::expr::ScalarExp<DType>(0.0f);
+ }
+
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
const int direction = param_.bidirectional ? 2 : 1;
const int bsize = GetRnnBiasSize(param_.num_layers, param_.state_size,
direction, param_.mode);
+
DType* db_ptr = dw.dptr_ + w.shape_[0] - bsize;
DType * dhy_ptr = NULL;
if (param_.state_outputs) {
dhy_ptr = out_grad[rnn_enum::kStateOut].dptr<DType>();
}
- DType * cx_ptr = NULL;
- DType * dcx_ptr = NULL;
- DType * dcy_ptr = NULL;
+ DType* dcx_ptr = NULL;
+ DType* dcy_ptr = NULL;
+ DType* cx_ptr = NULL;
if (param_.mode == rnn_enum::kLstm) {
      CHECK_NE(req[rnn_enum::kStateCell], kAddTo) << "AddTo is not supported for state cell";
- cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
- dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
- if (param_.state_outputs) {
- dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
- }
+ cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+ dcx_ptr = (in_grad[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
Review comment:
No, cx_ptr is different from hx.dptr_, and dcx_ptr is different from dhx.dptr_ as well.
cx_ptr and dcx_ptr are only used for LSTM.
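
For reference, here is a minimal standalone sketch of the pointer wiring described above, assuming the rnn_enum indices and TBlob accessors used in this diff (the local names hx_ptr and dhx_ptr are illustrative and do not appear in the PR). It only shows that the cell-state pointers stay NULL outside LSTM mode and never alias the hidden-state pointers:

    // Illustrative sketch only, not the PR's exact code.
    DType* hx_ptr  = in_data[rnn_enum::kState].dptr<DType>();   // hidden state h
    DType* dhx_ptr = in_grad[rnn_enum::kState].dptr<DType>();   // gradient of h
    DType* cx_ptr  = nullptr;   // cell state c, LSTM only
    DType* dcx_ptr = nullptr;   // gradient of c, LSTM only
    DType* dcy_ptr = nullptr;   // gradient of the final c, LSTM only
    if (param_.mode == rnn_enum::kLstm) {
      cx_ptr  = in_data[rnn_enum::kStateCell].dptr<DType>();
      dcx_ptr = in_grad[rnn_enum::kStateCell].dptr<DType>();
      if (param_.state_outputs) {
        dcy_ptr = out_grad[rnn_enum::kStateCellOut].dptr<DType>();
      }
    }
    // kState and kStateCell are separate blobs, so cx_ptr != hx_ptr and
    // dcx_ptr != dhx_ptr; for GRU and vanilla RNN the cell pointers remain NULL.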