D-Roberts commented on a change in pull request #18197:
URL: https://github.com/apache/incubator-mxnet/pull/18197#discussion_r425903506



##########
File path: src/operator/numpy/linalg/np_qr-inl.h
##########
@@ -542,36 +591,119 @@ void QrBackwardImpl(const TBlob& grad_a,
                     const nnvm::NodeAttrs& attrs) {
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& q_shape = q.shape_;
   const mxnet::TShape& r_shape = r.shape_;
   const int a_ndim = a_shape.ndim();
+  const int m = a.size(a_ndim - 2);
   const int n = a.size(a_ndim - 1);
 
   if (kNullOp == req[0]) { return; }
 
   if (0U == a_shape.Size()) { return; }
 
   MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
-    // case m >= n; Q of same shape with A and R is (n, n)
-    DType *m_ptr = reinterpret_cast<DType*>(workspace.dptr_);
-    DType *grad_a_ptr = m_ptr + r_shape.Size();
-    TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+    // common for all shapes (m, n)
+    DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
     TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);
-    // dR_T
-    mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
-      s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
-
-    qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
-                    grad_q.FlatToKD<xpu, 3, DType>(s),
-                    grad_r.FlatToKD<xpu, 3, DType>(s),
-                    a.FlatToKD<xpu, 3, DType>(s),
-                    q.FlatToKD<xpu, 3, DType>(s),
-                    r.FlatToKD<xpu, 3, DType>(s),
-                    temp_m.FlatToKD<xpu, 3, DType>(s),
-                    ctx, attrs);
-
+    if (m >= n) {
+      // Q has the same shape as A (m, n) and R is (n, n)
+      DType *m_ptr = grad_a_ptr + a_shape.Size();
+      TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+      // dR_T
+      mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+        s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
+      qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
+                      grad_q.FlatToKD<xpu, 3, DType>(s),
+                      grad_r.FlatToKD<xpu, 3, DType>(s),
+                      q.FlatToKD<xpu, 3, DType>(s),
+                      r.FlatToKD<xpu, 3, DType>(s),
+                      temp_m.FlatToKD<xpu, 3, DType>(s),
+                      ctx, attrs);
+    } else {
+      // R has the same shape as A (m, n) and Q is (m, m)
+      // Partition A = (X | Y); R = (U | V)
+      // X and U are (m, m); Y and V are (m, n - m)
+      mxnet::TShape v_shape(q_shape);
+      v_shape[a_ndim - 1] = n - m;
+
+      DType *m_ptr = grad_a_ptr + a_shape.Size();
+      DType *u_ptr = m_ptr + q_shape.Size();
+      DType *dq_prime_ptr = u_ptr + q_shape.Size();
+      DType *dv_ptr = dq_prime_ptr + q_shape.Size();
+      DType *y_ptr = dv_ptr + v_shape.Size();
+      DType *du_ptr = y_ptr + v_shape.Size();
+      DType *dx_ptr = du_ptr + q_shape.Size();
+      DType *dy_ptr = dx_ptr + q_shape.Size();
+
+      TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
+      TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
+      TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
+      TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
+      TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
+      TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
+      TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);
+      TBlob dy_data(dy_ptr, v_shape, xpu::kDevMask);
+
+      Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> dY = dy_data.FlatToKD<xpu, 3, DType>(s);
+      Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s);
+
+      // U
+      for (index_t i = 0; i < R.size(0); ++i) {
+        const Tensor<xpu, 2, DType>& Ri = R[i];
+        const Tensor<xpu, 2, DType>& Ui = U[i];
+        Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
+        Copy(Ui, Um, s);
+      }
+      // dU
+      for (index_t i = 0; i < dR.size(0); ++i) {
+        const Tensor<xpu, 2, DType>& dRi = dR[i];
+        const Tensor<xpu, 2, DType>& dUi = dU[i];
+        Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
+        Copy(dUi, dUm, s);
+      }
+      // Y
+      mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+        s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
+      // dV
+      mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+        s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
+      // store dU_T in M
+      mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+        s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
+      // dq_prime = dQ
+      Copy(dQ_prime, dQ, s);
+      // dq_prime = dQ+Y@dV.T
+      gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
+      // dX from qr_backward::op on the square partition (Q, U)
+      qr_backward::op(dX,
+                      dQ_prime,
+                      dU,
+                      Q,
+                      U,
+                      M,
+                      ctx, attrs);
+      // dY = Q@dV
+      gemm::op(Q, dV, dY, DType(1.0), DType(0.0), false, false, s);

Review comment:
       Yes, absolutely, I'd missed that one. Done!
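
       For future readers, here is a rough NumPy sketch of the composition this branch implements for the wide case m < n: partition A = (X | Y) and R = (U | V) with X, U square, fold dV into the square-case gradient through dQ' = dQ + Y @ dV.T, and recover dY = Q @ dV. This is only meant to document the math, not the kernel itself; the helper names (`qr_backward_deep`, `qr_backward_wide`) are illustrative, and the square-case formula below is the standard reduced-QR gradient rather than a copy of `qr_backward::op`.

```python
import numpy as np

def _solve_rt(x, r):
    # Computes x @ inv(r).T via a solve (r is the square upper-triangular factor).
    return np.linalg.solve(r, x.T).T

def qr_backward_deep(q, r, dq, dr):
    # Standard reduced-QR gradient for m >= n: q is (m, n), r is (n, n).
    qdq = q.T @ dq
    rdr = r @ dr.T
    tril = np.tril((qdq - qdq.T) + (rdr - rdr.T))
    grad_a = q @ (dr + _solve_rt(tril, r))
    grad_b = _solve_rt(dq - q @ qdq, r)
    return grad_a + grad_b

def qr_backward_wide(a, q, r, dq, dr):
    # m < n: A = (X | Y), R = (U | V), X and U are (m, m),
    # mirroring the partition in the else-branch above.
    m = a.shape[0]
    y = a[:, m:]                                # Y block of A
    u, du, dv = r[:, :m], dr[:, :m], dr[:, m:]  # U, dU, dV blocks
    dq_prime = dq + y @ dv.T                    # dQ' = dQ + Y @ dV.T
    dx = qr_backward_deep(q, u, dq_prime, du)   # square-case backward on (Q, U)
    dy = q @ dv                                 # dY = Q @ dV
    return np.concatenate([dx, dy], axis=1)     # dA = (dX | dY)
```

       A quick sanity check is to compare `qr_backward_wide` against a finite-difference gradient of a scalar function of `np.linalg.qr(a, mode='reduced')` on a small (2, 5) input; the batching and workspace handling stay as in the kernel above.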



