arcadiaphy commented on a change in pull request #15007: Add matrix determinant operator in linalg
URL: https://github.com/apache/incubator-mxnet/pull/15007#discussion_r287316084
 
 

 ##########
 File path: src/operator/tensor/la_op-inl.h
 ##########
 @@ -458,14 +458,90 @@ struct inverse {
   template<typename xpu, typename DType>
   static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>& 
A,
                  const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    // Reserve workspace (size determined by query)
-    int lwork(linalg_getri_workspace_query(A, s));
-    Tensor<xpu, 1, DType> work = ctx.requested[0]
-      .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
     // Since inverse(A) = trans(inverse(trans(A))), we don't need to transpose
     // A even if we are using the col-major version of getrf and getri routines.
-    linalg_batch_inverse(A, B, work, s);
+    linalg_batch_inverse(A, B, ctx);
+  }
+};
+
+// partial pivoting LU decomposition: A = PLU, so det(A) = det(P)det(L)det(U)
+// det(P) depends on the number of row changes in P, det(L) = 1, det(U) = prod(diag(U))
+// this kernel computes sign(det(A)), log(abs(det(A)))
+struct SignedLogDet {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, int N, int* pivot,
+                                  DType *LU, DType* sign, DType *logdet) {
+    int changes(0);
+    DType diag_sign(1);
+    DType diag_logsum(0);
+    int *pivot_mat = pivot + i * N;
+    DType *LU_mat = LU + i * N * N;
+    for (int j = 0; j < N; ++j) {
+      changes += (pivot_mat[j] != (j + 1));
+      DType diag = LU_mat[j * (N + 1)];
+      diag_sign *= ((DType(0) < diag) - (diag < DType(0)));
+      diag_logsum += std::log(std::abs(diag));
+    }
+    sign[i] = (changes % 2 == 1 ? DType(-1) : DType(1)) * diag_sign;
+    logdet[i] = diag_logsum;
+  }
+};
+
+// det = det(A), LU and pivot store the LU decomposition output which will be
+// used in computing gradient
+struct det {
+  template<typename xpu, typename DType>
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& 
det,
+                 const Tensor<xpu, 3, DType>& LU, const Tensor<xpu, 2, int>& 
pivot,
+                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 1, DType> sign = ctx.requested[0]
+      .get_space_typed<xpu, 1, DType>(det.shape_, s);
+    Copy(LU, A, s);
+    // since det(A) = det(trans(A)), we'll use col-major blas routines here
+    linalg_batch_getrf(LU, pivot, false, s);
+    using namespace mxnet_op;
+    using namespace mshadow::expr;
+    Kernel<SignedLogDet, xpu>::Launch(s, pivot.size(0), pivot.size(1), 
pivot.dptr_,
+                                      LU.dptr_, sign.dptr_, det.dptr_);
+    const_cast<Tensor<xpu, 1, DType>&>(det) = sign * F<mshadow_op::exp>(det);
+  }
+};
+
+// logdet = log(det(A))
+struct logdet {
+  template<typename xpu, typename DType>
+  static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 1, DType>& 
logdet,
+                 const Tensor<xpu, 3, DType>& LU, const Tensor<xpu, 2, int>& 
pivot,
+                 const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 1, DType> sign = ctx.requested[0]
+      .get_space_typed<xpu, 1, DType>(logdet.shape_, s);
+    Copy(LU, A, s);
+    linalg_batch_getrf(LU, pivot, false, s);
+    using namespace mxnet_op;
+    using namespace mshadow::expr;
+    Kernel<SignedLogDet, xpu>::Launch(s, pivot.size(0), pivot.size(1), 
pivot.dptr_,
+                                      LU.dptr_, sign.dptr_, logdet.dptr_);
+    const_cast<Tensor<xpu, 1, DType>&>(logdet) = F<mshadow_op::log>(sign) + 
logdet;
 
 Review comment:
   I'll remove `logdet`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to