sxjscience commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r272832692
 
 

 ##########
 File path: src/operator/rnn-inl.h
 ##########
 @@ -436,387 +563,901 @@ class RNNOp : public Operator{
     if (param_.state_outputs) {
       hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
     }
-    DType* cx_ptr = NULL;
-    DType* cy_ptr = NULL;
+    DType * cx_ptr = NULL;
+    DType * cy_ptr = NULL;
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
 
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
-    }
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
 
     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, 
param_.batch_size_,
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, 
param_.batch_size_,
                                                       param_.state_size, 
direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+    DType* work_cpu_space = NULL;
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
+    }
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+                              mshadow::Shape1(temp_size + 
work_cpu_space_size), s);
+
+    work_cpu_space = temp_space.dptr_ + temp_size;
 
 Review comment:
   I find that this part is only compiled when `defined(__CUDACC__)`. Thus
   `work_cpu_space` does not seem to be necessary.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to