lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r274253655
 
 

 ##########
 File path: src/operator/rnn.cc
 ##########
 @@ -97,13 +215,69 @@ The definition of GRU here is slightly different from paper but compatible with
             z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
             n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
             h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\
-            \end{array})code")
+            \end{array}
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<RNNParam>)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  //  kOut
+  int num_outputs = 1;
+  if (params.state_outputs) {
+    // kOut, kStateOut, kStateCellOut
+    num_outputs = (params.mode == rnn_enum::kLstm) ? 3 : 2;
+  }
+
+  return num_outputs;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return ListArguments(params);
+})
+.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
+.set_attr<nnvm::FInferType>("FInferType", RNNType)
+.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", RNNGrad{"_backward_RNN"})
+.set_attr<FResourceRequestEx>("FResourceRequestEx",
+  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode 
dispatch_mode) {
+    std::vector<ResourceRequest> request;
+    request.emplace_back(ResourceRequest::kTempSpace);
+    const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
+    if (param.p == 0) return request;
+    if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN_RNN
+      if (1.0f - param.p > 0) {
+        request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
+        return request;
+      }
+#endif
+    }
+    return request;
+})
 .add_argument("data", "NDArray-or-Symbol", "Input data to RNN")
 .add_argument("parameters", "NDArray-or-Symbol",
               "Vector of all RNN trainable parameters concatenated")
 .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN")
 .add_argument("state_cell", "NDArray-or-Symbol",
               "initial cell state for LSTM networks (only for LSTM)")
 .add_arguments(RNNParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_RNN)
+.set_num_outputs([](const NodeAttrs& attrs) {
+  const RNNParam& params = nnvm::get<RNNParam>(attrs.parsed);
+  return params.mode == rnn_enum::kLstm ? 4 : 3;
+})
+.set_attr_parser(ParamParser<RNNParam>)
+.set_attr<bool>("TIsLayerOpBackward", true)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 
 Review comment:
   fixed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to