piiswrong commented on a change in pull request #10311: [MXNET-107]Fused GRU 
implementation for CPU
URL: https://github.com/apache/incubator-mxnet/pull/10311#discussion_r192892126
 
 

 ##########
 File path: src/operator/rnn_impl.h
 ##########
 @@ -454,4 +458,796 @@ void LstmBackward(DType* ws,
     dy_ptr = dx.dptr_;
   }
 }
+
+// Forward inference of one GRU layer (uni- or bi-directional); nothing is
+// saved for a backward pass.
+// D: directions (1 or 2); T: sequence length; N: batch; I: input width;
+// H: hidden width.
+// ws:       workspace holding gemmC1 [D, T, N, 3H], gemmC2 [N, 3H] and the
+//           three per-step gate buffers rt/zt/nt
+// tmp_buf:  scratch used to transpose the [N, D*H] hidden state when D == 2
+// x:        input, [T * N, I]; hx: initial hidden state, [D * N, H]
+// wx_ptr/wh_ptr: input/hidden weights for both directions
+// bx_ptr/bh_ptr: input/hidden biases (may be NULL)
+// y_ptr:    output, [T, N, D * H]
+// hy_ptr:   final hidden state, [D, N, H]; written only if state_outputs
+template<typename DType>
+void GruForwardInferenceSingleLayer(DType* ws,
+                                    DType* tmp_buf,
+                                    bool state_outputs,
+                                    const int D,
+                                    const int T,
+                                    const int N,
+                                    const int I,
+                                    const int H,
+                                    const Tensor<cpu, 2, DType> &x,
+                                    const Tensor<cpu, 2, DType> &hx,
+                                    DType* wx_ptr,
+                                    DType* wh_ptr,
+                                    DType* bx_ptr,
+                                    DType* bh_ptr,
+                                    DType* y_ptr,
+                                    DType* hy_ptr) {
+  // ht/ht_1 walk forward through y_ptr; the backward direction starts at the
+  // last time step, second H-wide half of each [N, D * H] output row.
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T-1) * N * H * D + H;
+  DType* back_ht = back_ht_1;
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gemmC2 + N * 3 * H;
+  DType* zt = rt + N * H;
+  DType* nt = zt + N * H;
+  // Backward-direction parameters follow the forward ones in memory.
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2: NULL;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  // NOTE(review): bx/bh (and back_bx/back_bh) tensors are constructed even
+  // when the bias pointers are NULL — presumably callers always pass
+  // biases; confirm before relying on the NULL branches above.
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  // Seed step 0 with the initial hidden state hx; for D == 2 the backward
+  // rows hx[N..2N-1] seed the last time step.
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  // The input projections are independent of the recurrence, so they are
+  // batched into one gemm per direction up front.
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      // Transpose [N, D * H] -> [D, H, N] so the forward half
+      // dht_1_tmp[0] can be fed to the gemm as [H, N].
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, 
DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    // Elementwise GRU cell: r/z gates, candidate n, then the convex blend.
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      // The backward direction consumes time steps in reverse order.
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, 
true);
+
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j]+ back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      // Forward final state is at the last step; backward final state is
+      // the second half of step 0 (the backward pass ended there).
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+// Multi-layer GRU forward inference: runs the L stacked layers, feeding
+// each layer's output into the next.
+// ws layout: y_tmp [D, T, N, H] ping-pong output buffer, then tmp_buf
+// [D, H, N] transpose scratch, then ws2 (per-layer workspace).
+// I is deliberately non-const: after layer 0 the input width of every
+// later layer becomes D * H (the previous layer's output width).
+template <typename DType>
+void GruForwardInference(DType* ws,
+                         bool state_outputs,
+                         const int L,
+                         const int D,
+                         const int T,
+                         const int N,
+                         int I,
+                         const int H,
+                         DType* x_ptr,
+                         DType* hx_ptr,
+                         DType* w_ptr,
+                         DType* y_ptr,
+                         DType* hy_ptr) {
+  // Parameter blob layout: the wx/wh matrices of every layer and direction
+  // come first; bx skips over all of them to reach the bias section.
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* bh = bx + H * 3;
+
+  DType* y_tmp = ws;
+  DType* y_l = x_ptr;  // layer 0 reads the network input
+  DType* tmp_buf = y_tmp + D * T * N * H;
+  DType* ws2 = y_tmp + D * T * N * H + D * H * N;
+
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  for (int l = 0; l < L; l++) {
+    Tensor<cpu, 2, DType> x_l(y_l, Shape2(T * N, I));
+    // Ping-pong between y_tmp and y_ptr; for l == L - 1 the parity
+    // (L + l) % 2 is always 1, so the final layer writes into y_ptr.
+    if ((L + l) % 2) {
+      y_l = y_ptr;
+    } else {
+      y_l = y_tmp;
+    }
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    GruForwardInferenceSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, 
N, I, H,
+                                        x_l, hx_l, wx_l, wh_l, bx_l, bh_l, 
y_l, hy_l);
+    // Advance per-layer pointers: D * 2 bias vectors (bx/bh for each
+    // direction) and the wx+wh weight block of the next layer.
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + 3 * H * D * 2;
+    bh_l = bh_l + 3 * H * D * 2;
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {
+      I = D * H;  // subsequent layers consume the previous layer's output
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+}
+
+
+template<typename DType>
+void GruForwardTrainingSingleLayer(DType* ws,
+                                   DType* tmp_buf,
+                                   bool state_outputs,
+                                   const int D,
+                                   const int T,
+                                   const int N,
+                                   const int I,
+                                   const int H,
+                                   const Tensor<cpu, 2, DType> &x,
+                                   const Tensor<cpu, 2, DType> &hx,
+                                   DType* wx_ptr,
+                                   DType* wh_ptr,
+                                   DType* bx_ptr,
+                                   DType* bh_ptr,
+                                   DType* gateR,
+                                   DType* gateZ,
+                                   DType* gateN,
+                                   DType* Mnh,
+                                   DType* y_ptr,
+                                   DType* hy_ptr) {
+  DType* ht = y_ptr;
+  DType* ht_1 = y_ptr;
+  DType* back_ht_1 = y_ptr + (T - 1)* N * H * D + H;
+  DType* back_ht = back_ht_1;
+
+  DType* gemmC1  = ws;              // [D, T, N, 3 * H]
+  DType* gemmC2  = gemmC1 + D * T * N * 3 * H;  // N * 3 * H
+  DType* rt = gateR;
+  DType* zt = gateZ;
+  DType* nt = gateN;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_bx_ptr = (bx_ptr != NULL)? bx_ptr + 3 * H * 2 : NULL;
+  DType* back_bh_ptr = (bh_ptr != NULL)? bh_ptr + 3 * H * 2 : NULL;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_Mnh = Mnh + T * N * H;
+  DType* back_gemmC1 = gemmC1 + T * N * 3 * H;
+  DType* gemmC1_t = gemmC1;
+
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> bx(bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> bh(bh_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_bx(back_bx_ptr, Shape2(3, H));
+  const Tensor<cpu, 2, DType> back_bh(back_bh_ptr, Shape2(3, H));
+  const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  if (D == 1) {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * H + j] = hx[i][j];
+      }
+  } else {
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; i++)
+      for (int j = 0; j < H; j++) {
+        y_ptr[i * D * H + j] = hx[i][j];
+        back_ht_1[i * D * H + j] = hx[N + i][j];
+    }
+  }
+
+  Tensor<cpu, 2, DType> dgemmC1(ws, Shape2(T * N, 3 * H));
+  Tensor<cpu, 2, DType> dgemmC2(gemmC2, Shape2(N, 3 * H));
+  Tensor<cpu, 2, DType> dback_gemmC1(back_gemmC1, Shape2(T * N, 3 * H));
+
+  // x * wx.T : [T * N, I] * [I, 3 * H]
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  linalg_gemm(x, wx, dgemmC1, alpha, beta, false, true);
+  if (D == 2) {
+    linalg_gemm(x, back_wx, dback_gemmC1, alpha, beta, false, true);
+  }
+
+  for (int t = 0; t < T; t++) {
+    //  perform the first direction, X * wx and H * wh for each step
+    //  ht-1 * wh, ht-1:[N, H] wh:[3 * H, H]
+    Tensor<cpu, 2, DType> dht_1(ht_1, Shape2(N, D * H));
+    if (D == 1) {
+      linalg_gemm(dht_1, wh, dgemmC2, alpha, beta, false, true);
+    } else {
+      Tensor<cpu, 3, DType> dht_1_tmp = Tensor<cpu, 3, 
DType>(reinterpret_cast<DType*>(tmp_buf),
+                                     Shape3(D, H, N));
+      dht_1_tmp = reshape(dht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dht_1_tmp[0], wh, dgemmC2, alpha, beta, true, true);
+    }
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+
+    rt = gateR + t * N * H;
+    zt = gateZ + t * N * H;
+    nt = gateN + t * N * H;
+    gemmC1_t = gemmC1 + t * N * 3 * H;
+    DType* Mnht = Mnh + t * N * H;
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < N; ++i) {
+      for (int j = 0; j < H; ++j) {
+        int rtb = i * 3 * H;
+        int ztb = i * 3 * H + H;
+        int ntb = i * 3 * H + 2 * H;
+        Mnht[i * H + j] = gemmC2[ntb + j] + bh[2][j];
+        rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] + gemmC2[rtb + j]
+            + bx[0][j] + bh[0][j]);
+        zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] + gemmC2[ztb + j]
+            + bx[1][j] + bh[1][j]);
+        nt[i * H + j] = tanh(gemmC1_t[ntb + j] + bx[2][j] +
+            rt[i * H + j] * (gemmC2[ntb + j] + bh[2][j]));
+        ht[i * D * H + j] = (1-zt[i * H + j]) * nt[i * H + j] +
+            zt[i * H + j] * ht_1[i * D * H + j];
+      }
+    }
+    ht_1 = ht;
+    ht = ht + D * H * N;
+    //  perform the second direction
+    if (D == 2) {
+      rt = back_gateR + (T - 1 - t) * N * H;
+      zt = back_gateZ + (T - 1 - t) * N * H;
+      nt = back_gateN + (T - 1 - t) * N * H;
+      gemmC1_t = back_gemmC1 + (T - 1 - t) * N * 3 * H;
+      Tensor<cpu, 2, DType> dback_ht_1(back_ht_1, Shape2(N, D * H));
+      Tensor<cpu, 3, DType> dback_ht_1_tmp = Tensor<cpu, 3, DType>
+          (reinterpret_cast<DType*>(tmp_buf), Shape3(D, H, N));
+      dback_ht_1_tmp = reshape(dback_ht_1.T(), Shape3(D, H, N));
+      linalg_gemm(dback_ht_1_tmp[0], back_wh, dgemmC2, alpha, beta, true, 
true);
+
+      DType* back_Mnht = back_Mnh + (T - 1 - t) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; ++i) {
+        for (int j = 0; j < H; ++j) {
+          int rtb = i * 3 * H;
+          int ztb = i * 3 * H + H;
+          int ntb = i * 3 * H + 2 * H;
+          back_Mnht[i * H + j] = gemmC2[ntb + j] + back_bh[2][j];
+          rt[i * H + j] = sigmoid(gemmC1_t[rtb + j] +
+              gemmC2[rtb + j] + back_bx[0][j] + back_bh[0][j]);
+          zt[i * H + j] = sigmoid(gemmC1_t[ztb + j] +
+              gemmC2[ztb + j] + back_bx[1][j] + back_bh[1][j]);
+          nt[i * H + j] = tanh(gemmC1_t[ntb + j] + back_bx[2][j]
+              + rt[i * H + j] * (gemmC2[ntb + j] + back_bh[2][j]));
+          back_ht[i * D * H + j] = (1 - zt[i * H + j]) * nt[i * H + j]
+              + zt[i * H + j] * back_ht_1[i * D * H + j];
+        }
+      }
+      back_ht_1 = back_ht;
+      back_ht = back_ht - D * H * N;
+    }
+  }
+
+  //  copy last state to hy, from(N, H * D) to (D, N, H)
+  if (state_outputs) {
+    if (D == 1) {
+      DType* y_start = y_ptr + (T - 1) * N * H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * H + j];
+        }
+    } else {
+      DType* y_start = y_ptr + (T - 1) * N * H * D;
+      DType* y_back_start = y_ptr + H;
+      #pragma omp parallel for num_threads(omp_threads)
+      for (int i = 0; i < N; i++)
+        for (int j = 0; j < H; j++) {
+          hy_ptr[i * H + j] = y_start[i * D * H + j];
+          hy_ptr[N * H + i * H + j] = y_back_start[i * D * H + j];
+        }
+    }
+  }
+}
+
+// Multi-layer GRU forward pass for training. Like GruForwardInference, but
+// every layer writes its gate activations and full output sequence into the
+// reserved space `rs` so the backward pass can reuse them.
+// rs layout: gateR | gateZ | gateN | per-layer outputs y | Mnh, each block
+// of size L * T * D * N * H, followed by tmp_buf / ws2 scratch.
+// I is deliberately non-const: after layer 0 the input width becomes D * H.
+template <typename DType>
+void GruForwardTraining(DType* ws,
+                        DType* rs,
+                        bool state_outputs,
+                        const int L,
+                        const int D,
+                        const int T,
+                        const int N,
+                        int I,
+                        const int H,
+                        DType* x_ptr,
+                        DType* hx_ptr,
+                        DType* w_ptr,
+                        DType* y_ptr,
+                        DType* hy_ptr) {
+  // Parameter blob layout: all wx/wh matrices of every layer and direction
+  // first; bx skips over all of them to reach the bias section.
+  DType* wx = w_ptr;
+  DType* wh = wx + I * H * 3;
+  DType* bx = wh + H * H * 3 + (D - 1) * (H * H * 3 + I * H * 3)
+      + (L - 1) * ((D + 1) * H) * H * 3 * D;
+  DType* bh = bx + H * 3;
+  Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(D * L, N, H));
+  DType* hy_l = hy_ptr;
+  DType* gateR_l = rs;
+  DType* gateZ_l = gateR_l + L * T * D * N * H;
+  DType* gateN_l = gateZ_l + L * T * D * N * H;
+  DType* y_l = gateN_l + L * T * D * N * H;
+  DType* Mnh_l = y_l + L * T * N * H * D;
+  DType* tmp_buf = Mnh_l + L * D * T * N * H;
+  DType* ws2 = Mnh_l + L * D * T * N * H + D * H * N;
+  DType* wx_l = wx;
+  DType* wh_l = wh;
+  DType* bx_l = bx;
+  DType* bh_l = bh;
+  DType* y_tmp = x_ptr;
+
+  for (int l = 0; l < L; l++) {
+    // Layer 0 reads the network input; each later layer reads the previous
+    // layer's saved output and writes into the next y slot in rs.
+    if (l != 0) {
+      y_tmp = y_l;
+      y_l = y_l + T * N * H * D;
+    }
+    Tensor<cpu, 2, DType> x_l(y_tmp, Shape2(T * N, I));
+    Tensor<cpu, 2, DType> hx_l = hx[D * l];
+    GruForwardTrainingSingleLayer<DType>(ws2, tmp_buf, state_outputs, D, T, N, 
I, H,
+                                         x_l, hx_l, wx_l, wh_l, bx_l, bh_l,
+                                         gateR_l, gateZ_l, gateN_l, Mnh_l, 
y_l, hy_l);
+    // Advance per-layer pointers into rs and the parameter blob.
+    gateR_l = gateR_l + T * D * N * H;
+    gateZ_l = gateZ_l + T * D * N * H;
+    gateN_l = gateN_l + T * D * N * H;
+    Mnh_l = Mnh_l +  T * D * N * H;
+    hy_l = hy_l + D * N * H;
+    bx_l = bx_l + 3 * H * D * 2;
+    bh_l = bh_l + 3 * H * D * 2;
+
+    wx_l = wx_l + I * H * 3 * D + H * H * 3 * D;
+    if (l == 0) {
+      I = D * H;  // subsequent layers consume the previous layer's output
+    }
+    wh_l = wx_l + I * 3 * H;
+  }
+  // The last layer's output lives in rs; copy it out as the op's output.
+  memcpy(y_ptr, y_l, T * N * H * D * sizeof(DType));
+}
+
+template <typename DType>
+void GruBackwardSingleLayer(DType* ws,
+                            DType* tmp_buf,
+                            const int D,
+                            const int T,
+                            const int N,
+                            const int I,
+                            const int H,
+                            const Tensor<cpu, 2, DType> &x,
+                            const Tensor<cpu, 2, DType> &hx,
+                            DType* wx_ptr,
+                            DType* wh_ptr,
+                            DType* y_ptr,
+                            DType* dy_ptr,
+                            DType* dhy_ptr,
+                            DType* gateR,
+                            DType* gateZ,
+                            DType* gateN,
+                            DType* Mnh,
+                            DType* dx,
+                            DType* dhx,
+                            DType* dwx,
+                            DType* dwh,
+                            DType* dbx,
+                            DType* dbh) {
+  DType* dyt;
+  DType* ht1;  // [N, D, H]
+  DType* rt;
+  DType* zt;
+  DType* nt;
+  DType* dat;
+  DType* dart;
+  DType* dar = ws;  // [T, N, 3 * H]
+  DType* da = dar + T * N * 3 * H;  // [T, N, 3 * H]
+  DType* dht1 = da + T * N * 3 * H;  // [D, N, H]
+  DType* hx_ = dht1 + D * N * H;  // [N, D, H]
+  DType* Mnht = Mnh;
+
+  DType* back_ht1;
+  DType* back_dht1 = dht1 + N * H;  // [N, H]
+  DType* back_Mnht = Mnh + T * N * H;
+  DType* back_gateR = gateR + T * N * H;
+  DType* back_gateZ = gateZ + T * N * H;
+  DType* back_gateN = gateN + T * N * H;
+  DType* back_wx_ptr = wx_ptr + I * 3 * H + H * 3 * H;
+  DType* back_wh_ptr = wh_ptr + I * 3 * H + H * 3 * H;
+  DType* back_dwx = dwx + I * 3 * H + H * 3 * H;
+  DType* back_dwh = dwh + I * 3 * H + H * 3 * H;
+  DType* back_dbx = dbx + 3 * H * 2;
+  DType* back_dbh = dbh + 3 * H * 2;
+
+  DType alpha = 1.0;
+  DType beta = 0.0;
+  const Tensor<cpu, 2, DType> wx(wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> wh(wh_ptr, Shape2(H * 3, H));
+  const Tensor<cpu, 2, DType> back_wx(back_wx_ptr, Shape2(H * 3, I));
+  const Tensor<cpu, 2, DType> back_wh(back_wh_ptr, Shape2(H * 3, H));
+  const int omp_threads = 
mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < D * H * 3 * H; ++i) {
+    dwh[i] = 0;
 
 Review comment:
   This still zero-initializes `dwh` unconditionally, so it overwrites the existing gradient even when `req[kParams] == kAddTo`. In the `kAddTo` case the computed weight gradients need to be accumulated into the output buffer rather than written over it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to