D-Roberts commented on a change in pull request #18197:
URL: https://github.com/apache/incubator-mxnet/pull/18197#discussion_r425903506
##
File path: src/operator/numpy/linalg/np_qr-inl.h
##
@@ -542,36 +591,119 @@ void QrBackwardImpl(const TBlob& grad_a,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TShape& a_shape = a.shape_;
+ const mxnet::TShape& q_shape = q.shape_;
const mxnet::TShape& r_shape = r.shape_;
const int a_ndim = a_shape.ndim();
+ const int m = a.size(a_ndim - 2);
const int n = a.size(a_ndim - 1);
if (kNullOp == req[0]) { return; }
if (0U == a_shape.Size()) { return; }
MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
-// case m >= n; Q of same shape with A and R is (n, n)
-DType *m_ptr = reinterpret_cast<DType*>(workspace.dptr_);
-DType *grad_a_ptr = m_ptr + r_shape.Size();
-TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+// common for all shapes (m, n)
+DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);
-// dR_T
-mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
- s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
-
-qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
-grad_q.FlatToKD<xpu, 3, DType>(s),
-grad_r.FlatToKD<xpu, 3, DType>(s),
-a.FlatToKD<xpu, 3, DType>(s),
-q.FlatToKD<xpu, 3, DType>(s),
-r.FlatToKD<xpu, 3, DType>(s),
-temp_m.FlatToKD<xpu, 3, DType>(s),
-ctx, attrs);
-
+if (m >= n) {
+ // Q of same shape with A (m, n) and R is (n, n)
+ DType *m_ptr = grad_a_ptr + a_shape.Size();
+ TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
+ // dR_T
+ mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
+ qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
+ grad_q.FlatToKD<xpu, 3, DType>(s),
+ grad_r.FlatToKD<xpu, 3, DType>(s),
+ q.FlatToKD<xpu, 3, DType>(s),
+ r.FlatToKD<xpu, 3, DType>(s),
+ temp_m.FlatToKD<xpu, 3, DType>(s),
+ ctx, attrs);
+} else {
+ // R is same shape with A (m, n) and Q is (m, m)
+ // Partition A = (X | Y); R = (U | V)
+ // X and U are (m, m); Y and V are (m, n - m)
+ mxnet::TShape v_shape(q_shape);
+ v_shape[a_ndim - 1] = n - m;
+
+ DType *m_ptr = grad_a_ptr + a_shape.Size();
+ DType *u_ptr = m_ptr + q_shape.Size();
+ DType *dq_prime_ptr = u_ptr + q_shape.Size();
+ DType *dv_ptr = dq_prime_ptr + q_shape.Size();
+ DType *y_ptr = dv_ptr + v_shape.Size();
+ DType *du_ptr = y_ptr + v_shape.Size();
+ DType *dx_ptr = du_ptr + q_shape.Size();
+ DType *dy_ptr = dx_ptr + q_shape.Size();
+
+ TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
+ TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
+ TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
+ TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
+ TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
+ TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
+ TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);
+ TBlob dy_data(dy_ptr, v_shape, xpu::kDevMask);
+
+ Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> dY = dy_data.FlatToKD<xpu, 3, DType>(s);
+ Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s);
+
+ // U
+ for (index_t i = 0; i < R.size(0); ++i) {
+const Tensor<xpu, 2, DType>& Ri = R[i];
+const Tensor<xpu, 2, DType>& Ui = U[i];
+Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
+Copy(Ui, Um, s);
+ }
+ // dU
+ for (index_t i = 0; i < dR.size(0); ++i) {
+const Tensor<xpu, 2, DType>& dRi = dR[i];
+const Tensor<xpu, 2, DType>& dUi = dU[i];
+Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
+Copy(dUi, dUm, s);
+ }
+ // Y
+ mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
+ // dV
+ mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
+s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
+ // store dU_T in M
+ mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
+ // dq_prime = dQ
+ Copy(dQ_prime, dQ, s);
+ // dq_prime = dQ+Y@dV.T
+ gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
+ // dX = op call
+ qr_backward::op(dX,
+ dQ_prime,
+ dU,
+ Q,
+ U,
+ M,
+ ctx, attrs);
+ // dY = Q@dV
+ gemm::op(Q, dV, dY,