lihaofd commented on a change in pull request #14476: Change RNN OP to stateful
URL: https://github.com/apache/incubator-mxnet/pull/14476#discussion_r272868052
 
 

 ##########
 File path: src/operator/rnn-inl.h
 ##########
 @@ -436,387 +563,901 @@ class RNNOp : public Operator{
     if (param_.state_outputs) {
       hy_ptr = out_data[rnn_enum::kStateOut].dptr<DType>();
     }
-    DType* cx_ptr = NULL;
-    DType* cy_ptr = NULL;
+    DType * cx_ptr = NULL;
+    DType * cy_ptr = NULL;
+    if (param_.mode == rnn_enum::kLstm)
+      cx_ptr = (in_data[rnn_enum::kStateCell].get<xpu, 3, DType>(s)).dptr_;
+    if (param_.mode == rnn_enum::kLstm && param_.state_outputs)
+      cy_ptr = (out_data[rnn_enum::kStateCellOut].get<xpu, 3, DType>(s)).dptr_;
 
-    if (param_.mode == rnn_enum::kLstm) {
-      cx_ptr = in_data[rnn_enum::kStateCell].dptr<DType>();
-      if (param_.state_outputs) {
-        cy_ptr = out_data[rnn_enum::kStateCellOut].dptr<DType>();
-      }
-    }
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
 
     // allocate temp space
-    const size_t workspace_size = GetRNNWorkspaceSize(param_.seq_length_, 
param_.batch_size_,
+    const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, 
param_.batch_size_,
                                                       param_.state_size, 
direction, param_.mode);
-    Tensor<cpu, 1, DType> workspace = ctx.requested[rnn_enum::kTempSpace]
-        .get_space_typed<cpu, 1, DType>(Shape1(workspace_size), s);
+    DType* work_cpu_space = NULL;
+    #if MXNET_USE_CUDNN_RNN && defined(__CUDACC__)
+    if (!init_cudnn_) {
+      Init(ctx, s, in_data, out_data);
+    }
+    // Get temp space
+    int temp_size = workspace_size_;
+    Tensor<xpu, 1, DType> temp_space =
+      ctx.requested[rnn_enum::kTempSpace].get_space_typed<xpu, 1, DType>(
+                              mshadow::Shape1(temp_size + 
work_cpu_space_size), s);
+
+    work_cpu_space = temp_space.dptr_ + temp_size;
 
 Review comment:
   fixed

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to