[GitHub] [incubator-mxnet] D-Roberts commented on a change in pull request #18197: [Numpy] Add qr backward part 2 for wide matrices with m < n

2020-05-15 Thread GitBox


D-Roberts commented on a change in pull request #18197:
URL: https://github.com/apache/incubator-mxnet/pull/18197#discussion_r425903506



##
File path: src/operator/numpy/linalg/np_qr-inl.h
##
@@ -542,36 +591,119 @@ void QrBackwardImpl(const TBlob& grad_a,
 const nnvm::NodeAttrs& attrs) {
   Stream *s = ctx.get_stream();
   const mxnet::TShape& a_shape = a.shape_;
+  const mxnet::TShape& q_shape = q.shape_;
   const mxnet::TShape& r_shape = r.shape_;
   const int a_ndim = a_shape.ndim();
+  const int m = a.size(a_ndim - 2);
   const int n = a.size(a_ndim - 1);
 
   if (kNullOp == req[0]) { return; }
 
   if (0U == a_shape.Size()) { return; }
 
   MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
-// case m >= n; Q of same shape with A and R is (n, n)
-DType *m_ptr = reinterpret_cast(workspace.dptr_);
-DType *grad_a_ptr = m_ptr + r_shape.Size();
-TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+// common for all shapes (m, n)
+DType *grad_a_ptr = reinterpret_cast(workspace.dptr_);
 TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);
-// dR_T
-mxnet_op::Kernel::Launch(
-  s, r_shape.Size(), grad_r.dptr(), m_ptr, n, n, n * n);
-
-qr_backward::op(grad_a_data.FlatToKD(s),
-grad_q.FlatToKD(s),
-grad_r.FlatToKD(s),
-a.FlatToKD(s),
-q.FlatToKD(s),
-r.FlatToKD(s),
-temp_m.FlatToKD(s),
-ctx, attrs);
-
+if (m >= n) {
+  // Q of same shape with A (m, n) and R is (n, n)
+  DType *m_ptr = grad_a_ptr + a_shape.Size();
+  TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+  // dR_T
+  mxnet_op::Kernel::Launch(
+s, r_shape.Size(), grad_r.dptr(), m_ptr, n, n, n * n);
+  qr_backward::op(grad_a_data.FlatToKD(s),
+  grad_q.FlatToKD(s),
+  grad_r.FlatToKD(s),
+  q.FlatToKD(s),
+  r.FlatToKD(s),
+  temp_m.FlatToKD(s),
+  ctx, attrs);
+} else {
+  // R is same shape with A (m, n) and Q is (m, m)
+  // Partition A = (X | Y); R = (U | V)
+  // X and U are (m, m); Y and V are (m, n - m)
+  mxnet::TShape v_shape(q_shape);
+  v_shape[a_ndim - 1] = n - m;
+
+  DType *m_ptr = grad_a_ptr + a_shape.Size();
+  DType *u_ptr = m_ptr + q_shape.Size();
+  DType *dq_prime_ptr = u_ptr + q_shape.Size();
+  DType *dv_ptr = dq_prime_ptr + q_shape.Size();
+  DType *y_ptr = dv_ptr + v_shape.Size();
+  DType *du_ptr = y_ptr + v_shape.Size();
+  DType *dx_ptr = du_ptr + q_shape.Size();
+  DType *dy_ptr = dx_ptr + q_shape.Size();
+
+  TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
+  TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
+  TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
+  TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
+  TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
+  TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
+  TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);
+  TBlob dy_data(dy_ptr, v_shape, xpu::kDevMask);
+
+  Tensor R = r.FlatToKD(s);
+  Tensor dR = grad_r.FlatToKD(s);
+  Tensor Q = q.FlatToKD(s);
+  Tensor dQ = grad_q.FlatToKD(s);
+  Tensor dQ_prime = dq_prime_data.FlatToKD(s);
+  Tensor A = a.FlatToKD(s);
+  Tensor dA = grad_a_data.FlatToKD(s);
+  Tensor U = u_data.FlatToKD(s);
+  Tensor dU = du_data.FlatToKD(s);
+  Tensor dV = dv_data.FlatToKD(s);
+  Tensor Y = y_data.FlatToKD(s);
+  Tensor dX = dx_data.FlatToKD(s);
+  Tensor dY = dy_data.FlatToKD(s);
+  Tensor M = temp_m.FlatToKD(s);
+
+  // U
+  for (index_t i = 0; i < R.size(0); ++i) {
+const Tensor& Ri = R[i];
+const Tensor& Ui = U[i];
+Tensor Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
+Copy(Ui, Um, s);
+  }
+  // dU
+  for (index_t i = 0; i < dR.size(0); ++i) {
+const Tensor& dRi = dR[i];
+const Tensor& dUi = dU[i];
+Tensor dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
+Copy(dUi, dUm, s);
+  }
+  // Y
+  mxnet_op::Kernel::Launch(
+s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
+  // dV
+  mxnet_op::Kernel::Launch(
+s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
+  // store dU_T in M
+  mxnet_op::Kernel::Launch(
+s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
+  // dq_prime = dQ
+  Copy(dQ_prime, dQ, s);
+  // dq_prime = dQ+Y@dV.T
+  gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
+  // dX = op call
+  qr_backward::op(dX,
+  dQ_prime,
+  dU,
+  Q,
+  U,
+  M,
+  ctx, attrs);
+  // dY = Q@dV
+  gemm::op(Q, dV, dY, 

[GitHub] [incubator-mxnet] D-Roberts commented on a change in pull request #18197: [Numpy] Add qr backward part 2 for wide matrices with m < n

2020-05-14 Thread GitBox


D-Roberts commented on a change in pull request #18197:
URL: https://github.com/apache/incubator-mxnet/pull/18197#discussion_r425140133



##
File path: src/operator/numpy/linalg/np_qr-inl.h
##
@@ -514,15 +548,28 @@ struct qr_backward {
 
 template
 size_t QrBackwardWorkspaceSize(const TBlob& a,
+   const TBlob& q,
const TBlob& r,
const TBlob& grad_a) {
+  const mxnet::TShape& a_shape = a.shape_;
+  const int a_ndim = a_shape.ndim();
+  const int n = a.size(a_ndim - 1);
+  const int m = a.size(a_ndim - 2);
+
   if (0U == a.Size()) { return 0U; }
 
   MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
 size_t work_space_size = 0;
-// for grad a and M
 work_space_size += a.Size();
-work_space_size += r.Size();
+if (m >= n) {
+  work_space_size += r.Size();
+} else {
+  const mxnet::TShape& q_shape = q.shape_;
+  mxnet::TShape v_shape(q_shape);
+  v_shape[a_ndim - 1] = n - m;
+  work_space_size += 5 * q.Size();

Review comment:
   @hzfan Thank you for your review. Added comments and updated reference 
in the code.





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org