lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r274253655
##########
File path: src/operator/rnn.cc
##########
@@ -97,13 +215,69 @@ The definition of GRU here is slightly different from paper but compatible with
z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
- \end{array})code")
+ \end{array}
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<RNNParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+ const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+ return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+ const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+ // kOut
+ int num_outputs = 1;
+ if (params.state_outputs) {
+ // kOut, kStateOut, kStateCellOut
+ num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2;
+ }
+
+ return num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+ return ListArguments(params);
+})
+.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
+.set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
+.set_attr<FResourceRequestEx>("FResourceRequestEx",
+ [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode
dispatch_mode) {
+ std::vector<ResourceRequest> request;
+ request.emplace_back(ResourceRequest::kTempSpace);
+ const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+ if (param.p == 0) return request;
+ if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN_RNN
+ if (1.0f - param.p > 0) {
+ request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
+ return request;
+ }
+#endif
+ }
+ return request;
+})
.add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
.add_argument("parameters", "NDArray-or-Symbol",
"Vector of all RNN trainable parameters concatenated")
.add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
.add_argument("state_cell", "NDArray-or-Symbol",
"initial cell state for LSTM networks (only for LSTM)")
.add_arguments(RNNParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_RNN)
+.set_num_outputs([](const NodeAttrs& attrs) {
+ const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+ return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_attr_parser(ParamParser<RNNParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
Review comment:
fixed
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services