zheng-da commented on a change in pull request #9977: Cpu lstm inference
URL: https://github.com/apache/incubator-mxnet/pull/9977#discussion_r172712476
 
 

 ##########
 File path: src/operator/rnn-inl.h
 ##########
 @@ -144,15 +149,224 @@ class RNNOp : public Operator {
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &in_grad,
                         const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
     // TODO(sbodenstein): add MShadow implementation
   }
 
  private:
   RNNParam param_;
 };  // class RNNOp
 
+template<typename DType>
+class RNNOp<cpu, DType> : public Operator {
+ public:
+  explicit RNNOp(RNNParam param) {
+    this->param_ = param;
+    // RNN Mode
+    switch (param_.mode) {
+      case rnn_enum::kLstm:
+        break;
+      default:
+        LOG(FATAL) << "only LSTM is implemented on CPU";
+    }
+    if (param_.mode == rnn_enum::kLstm)
+      param_.lstm_q_ = true;
+    else
+      param_.lstm_q_ = false;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    // Layout TNC
+
+    size_t in_expected = param_.lstm_q_ ? 4 : 3;
+    size_t out_expected = param_.lstm_q_ ? 3 : 2;
+
+    if (!param_.state_outputs)
+      LOG(FATAL) << "no state outputs is currently not supported for cpu.";
+
+    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    // w layout i2h_w, h2h_w, i2h_b, h2h_b
+    Tensor<cpu, 3, DType> x =
+        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx =
+        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
+    Tensor<cpu, 3, DType> y =
+        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
+    int64_t seq_len = x.shape_[0];
+    int64_t num_layers = hx.shape_[0];
+    int64_t batch_size = x.shape_[1];
+    int64_t h_channel = hx.shape_[2];
+    int64_t in_channel = x.shape_[2];
+    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
+      .get_with_shape<cpu, 2, DType>(
+          mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
+    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
+      .get_with_shape<cpu, 2, DType>(
+          mshadow::Shape2(
+              y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+    if (ctx.is_train)
+      LOG(FATAL) << "only inference mode is available for cpu at the moment.";
 
 Review comment:
   You can do CHECK(!ctx.is_train) << "..." here instead of the if/LOG(FATAL) pair.
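
   For illustration, a minimal sketch of that suggestion (keeping the existing message; dmlc's CHECK aborts with the streamed text when the condition is false):

       // replaces the explicit if (ctx.is_train) LOG(FATAL) pair
       CHECK(!ctx.is_train) << "only inference mode is available for cpu at the moment.";

   The same pattern would also fit the state_outputs guard earlier in Forward, e.g. CHECK(param_.state_outputs) << "...".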

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
