This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8c74974 support for upper triangular matrices in linalg (#12904)
8c74974 is described below
commit 8c74974e8b03205f08cca0cddf485bd30a24ad73
Author: moin <[email protected]>
AuthorDate: Tue Nov 13 18:11:03 2018 +0100
support for upper triangular matrices in linalg (#12904)
---
src/operator/tensor/la_op-inl.h | 236 ++++++++++++++++-------------
src/operator/tensor/la_op.cc | 21 ++-
src/operator/tensor/la_op.h | 16 ++
tests/python/unittest/test_operator.py | 269 +++++++++++++++++----------------
4 files changed, 300 insertions(+), 242 deletions(-)
diff --git a/src/operator/tensor/la_op-inl.h b/src/operator/tensor/la_op-inl.h
index b3353e2..e89a082 100644
--- a/src/operator/tensor/la_op-inl.h
+++ b/src/operator/tensor/la_op-inl.h
@@ -21,6 +21,7 @@
* Copyright (c) 2017 by Contributors
* \file la_op-inl.h
* \brief Operators for advanced linear algebra.
+ * \note See https://arxiv.org/pdf/1710.08717.pdf for details of gradient
computations.
*/
#ifndef MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
#define MXNET_OPERATOR_TENSOR_LA_OP_INL_H_
@@ -32,20 +33,29 @@ namespace op {
using namespace mshadow;
-// Helper functions.
-struct CopyLowerToUpper {
+// Copies lower/upper triangular part to upper/lower, i.e. to the opposite
side.
+struct CopyTriangularToOppositeSide {
template<typename DType>
- MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType*
data) {
+ MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType*
data, bool to_lower) {
// Below computation works even when we are dealing with a batch of
matrices.
const int row((i % matrix_size) / stride), col(i % stride);
- if ( row > col ) data[i + (col - row) * (stride - 1)] = data[i];
+ if (row > col) {
+ if (to_lower) {
+ data[i] = data[i + (col - row) * (stride - 1)];
+ } else {
+ data[i + (col - row) * (stride - 1)] = data[i];
+ }
+ }
}
};
-struct ZeroUpper {
+
+// Zeroes the lower/upper triangular part of a matrix.
+struct ZeroTriangular {
template<typename DType>
- MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType*
data) {
+ MSHADOW_XINLINE static void Map(int i, int matrix_size, int stride, DType*
data,
+ bool zero_lower) {
const int row((i % matrix_size) / stride), col(i % stride);
- if ( row < col ) data[i] = 0;
+ if ((!zero_lower && (row < col)) || (zero_lower && (row > col))) data[i] =
0;
}
};
struct Scale {
@@ -103,87 +113,91 @@ struct gemm2 {
}
};
-// L = potrf(A).
+// B = potrf(A).
struct potrf {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
L,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
- if ( A.dptr_ != L.dptr_ ) Copy(L, A, s);
- linalg_batch_potrf(L, true, s);
+ const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+ if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+ linalg_batch_potrf(B, param.lower, s);
using namespace mxnet_op;
- Kernel<ZeroUpper, xpu>::Launch(s, L.MSize(), L.size(1)*L.stride_,
L.stride_, L.dptr_);
+ Kernel<ZeroTriangular, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
B.stride_,
+ B.dptr_, !param.lower);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
L,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(A, L, s, attrs);
+ op(A, B, s, attrs);
}
};
-// A = potri(L).
+// A = potri(B).
struct potri {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
+ static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
A,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
- if ( A.dptr_ != L.dptr_ ) Copy(A, L, s);
- linalg_batch_potri(A, true, s);
+ const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+ if ( A.dptr_ != B.dptr_ ) Copy(A, B, s);
+ linalg_batch_potri(A, param.lower, s);
using namespace mxnet_op;
- Kernel<CopyLowerToUpper, xpu>::Launch(s, A.MSize(), A.size(1)*A.stride_,
A.stride_, A.dptr_);
+ Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, A.MSize(),
A.size(1)*A.stride_, A.stride_,
+ A.dptr_, !param.lower);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
+ static void op(const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
A,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(L, A, s, attrs);
+ op(B, A, s, attrs);
}
};
-// B = trsm(L,A)
+// C = trsm(A,B)
struct trsm {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
B,
- DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
- linalg_batch_trsm(L, B, alpha, rightside, true, transpose, s);
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
C,
+ DType alpha, bool rightside, bool lower, bool transpose,
Stream<xpu> *s) {
+ linalg_batch_trsm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
- const Tensor<xpu, 3, DType>& B,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& C,
Stream<xpu> *s, const nnvm::NodeAttrs& attrs) {
- if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+ if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param =
nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
- op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
+ op(A, C, DType(param.alpha), param.rightside, param.lower,
param.transpose, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
- const Tensor<xpu, 3, DType>& B,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& C,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(L, A, B, s, attrs);
+ op(A, B, C, s, attrs);
}
};
-// B = trmm(L,A)
+// C = trmm(A,B)
struct trmm {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
B,
- DType alpha, bool rightside, bool transpose, Stream<xpu> *s) {
- linalg_batch_trmm(L, B, alpha, rightside, true, transpose, s);
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
C,
+ DType alpha, bool rightside, bool lower, bool transpose,
Stream<xpu> *s) {
+ linalg_batch_trmm(A, C, alpha, rightside, lower, transpose, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
- const Tensor<xpu, 3, DType>& B, Stream<xpu> *s,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& C, Stream<xpu> *s,
const nnvm::NodeAttrs& attrs) {
- if ( A.dptr_ != B.dptr_ ) Copy(B, A, s);
+ if ( B.dptr_ != C.dptr_ ) Copy(C, B, s);
const LaTriangMatrixMultParam& param =
nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
- op(L, B, DType(param.alpha), param.rightside, param.transpose, s);
+ op(A, C, DType(param.alpha), param.rightside, param.lower,
param.transpose, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& L, const Tensor<xpu, 3, DType>&
A,
- const Tensor<xpu, 3, DType>& B, const OpContext& ctx,
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& C, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(L, A, B, s, attrs);
+ op(A, B, C, s, attrs);
}
};
@@ -223,8 +237,8 @@ struct syrk {
linalg_batch_syrk(A, B, alpha, beta, tA, s);
// Symmetric B is in lower triangle: Copy to upper
using namespace mxnet_op;
- Kernel<CopyLowerToUpper, xpu>::Launch(s, B.MSize(), B.size(1)*B.stride_,
- B.stride_, B.dptr_);
+ Kernel<CopyTriangularToOppositeSide, xpu>::Launch(s, B.MSize(),
B.size(1)*B.stride_,
+ B.stride_, B.dptr_, false);
}
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
@@ -276,8 +290,8 @@ struct gelqf {
Tensor<xpu, 2, DType> QLeft(Qi.dptr_, Shape2(m, m), Qi.stride_, s);
Copy(Li, QLeft, s);
using namespace mxnet_op;
- Kernel<ZeroUpper, xpu>::Launch(s, Li.MSize(), m*Li.stride_, Li.stride_,
- Li.dptr_);
+ Kernel<ZeroTriangular, xpu>::Launch(s, Li.MSize(), m*Li.stride_,
Li.stride_,
+ Li.dptr_, false);
// Call orglq: Input is Qi and part of work. Overwrites Qi by final Q
// matrix (conversion from internal representation)
linalg_orglq(Qi, work, s);
@@ -395,117 +409,129 @@ struct gemm2_backward {
struct potrf_backward {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>&
L,
+ static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
B,
const Tensor<xpu, 3, DType>& dA,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
- // Backward of L = potrf(A).
- // dA = 0.5 * L**T * copyLTU(L**T * dL) * L**(-1)
+ // Backward of B = potrf(A).
+ // dA = 0.5 * B**T * copyLTU(B**T * dB) * B**(-1)
// Here, copyLTU(M) creates a symmetric matrix from the square matrix M
// by setting the upper triangle to be equal to the lower triangle, leaving
// lower triangle and diagonal unchanged.
- if ( dL.dptr_ != dA.dptr_ ) {
- Copy(dA, dL, s);
+ // The function also handles the case when B is upper triangular by
appropriate
+ // transpositions.
+ const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+ if ( dB.dptr_ != dA.dptr_ ) {
+ Copy(dA, dB, s);
}
- trmm::op(L, dA, DType(1.0), false, true, s);
+ trmm::op(B, dA, DType(1.0), !param.lower, param.lower, true, s);
using namespace mxnet_op;
- Kernel<CopyLowerToUpper, xpu>::Launch
- (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_);
- trsm::op(L, dA, DType(1.0), false, true, s);
- trsm::op(L, dA, DType(0.5), true, false, s);
+ Kernel<CopyTriangularToOppositeSide, xpu>::Launch
+ (s, dA.MSize(), dA.size(1)*dA.stride_, dA.stride_, dA.dptr_,
!param.lower);
+ trsm::op(B, dA, DType(1.0), false, param.lower, param.lower, s);
+ trsm::op(B, dA, DType(0.5), true, param.lower, !param.lower, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>&
L,
+ static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
B,
const Tensor<xpu, 3, DType>& dA,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(dL, L, dA, s, attrs);
+ op(dB, B, dA, s, attrs);
}
};
struct potri_backward {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dL,
+ static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
- // Backward of A = potri(L).
- // dL = -tril( A * (dA + dA**T) * L**(-T)), where tril() extracts lower
triangle
+ // Backward of A = potri(B).
+ // dB = -tril( A * (dA + dA**T) * B**(-T)), where tril() extracts lower
triangle
// and diagonal. We must not assume that dA is symmetric.
+ // The function also handles the case when B is upper triangular by
appropriate
+ // transpositions.
// Note: Calling gemm twice here is a bit wasteful, but otherwise the
symmetrization
// of dA would require temporary memory.
- gemm::op(A, dA, dL, DType(1.), DType(0.), false, false, s);
- gemm::op(A, dA, dL, DType(1.), DType(1.), false, true, s);
- trsm::op(L, dL, DType(-1.), true, true, s);
+ const LaCholeskyParam& param = nnvm::get<LaCholeskyParam>(attrs.parsed);
+ if (param.lower) {
+ gemm::op(A, dA, dB, DType(1.), DType(0.), false, false, s);
+ gemm::op(A, dA, dB, DType(1.), DType(1.), false, true, s);
+ } else {
+ gemm::op(dA, A, dB, DType(1.), DType(0.), false, false, s);
+ gemm::op(dA, A, dB, DType(1.), DType(1.), true, false, s);
+ }
+ trsm::op(B, dB, DType(-1.), param.lower, param.lower, true, s);
using namespace mxnet_op;
- Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_,
dL.stride_,
- dL.dptr_);
+ Kernel<ZeroTriangular, xpu>::Launch(s, dB.MSize(), dB.size(1)*dB.stride_,
dB.stride_,
+ dB.dptr_, !param.lower);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dL,
+ static void op(const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
B,
+ const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(dA, L, A, dL, s, attrs);
+ op(dA, B, A, dB, s, attrs);
}
};
struct trsm_backward {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
- const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>&
dA,
+ static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>&
A,
+ const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
C,
+ const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
dB,
Stream<xpu>* s, const nnvm::NodeAttrs& attrs) {
- // Backward of B = trsm(L,A).
+ // Backward of C = trsm(A,B).
const LaTriangMatrixMultParam& param =
nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
+ // Compute dB
+ if ( dB.dptr_ != dC.dptr_ ) Copy(dB, dC, s);
+ trsm::op(A, dB, DType(param.alpha), param.rightside, param.lower,
!param.transpose, s);
// Compute dA
- if ( dA.dptr_ != dB.dptr_ ) Copy(dA, dB, s);
- trsm::op(L, dA, DType(param.alpha), param.rightside, !param.transpose, s);
- // Compute dL
const bool da_left(param.rightside == param.transpose);
DType scale(-1.0/param.alpha);
- (da_left ? gemm::op(dA, B, dL, scale, DType(0), param.transpose,
!param.transpose, s)
- : gemm::op(B, dA, dL, scale, DType(0), !param.transpose,
param.transpose, s));
+ (da_left ? gemm::op(dB, C, dA, scale, DType(0), param.transpose,
!param.transpose, s)
+ : gemm::op(C, dB, dA, scale, DType(0), !param.transpose,
param.transpose, s));
using namespace mxnet_op;
- Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_,
dL.stride_, dL.dptr_);
+ Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_,
dA.stride_,
+ dA.dptr_, !param.lower);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
B,
- const Tensor<xpu, 3, DType>& dL, const Tensor<xpu, 3, DType>&
dA,
+ static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>&
A,
+ const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
C,
+ const Tensor<xpu, 3, DType>& dA, const Tensor<xpu, 3, DType>&
dB,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(dB, L, A, B, dL, dA, s, attrs);
+ op(dC, A, B, C, dA, dB, s, attrs);
}
};
struct trmm_backward {
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dL,
- const Tensor<xpu, 3, DType>& dA, Stream<xpu>* s,
+ static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>&
A,
+ const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
dA,
+ const Tensor<xpu, 3, DType>& dB, Stream<xpu>* s,
const nnvm::NodeAttrs& attrs) {
- // Backward of B = trmm(L,A).
+ // Backward of C = trmm(A,B).
const LaTriangMatrixMultParam& param =
nnvm::get<LaTriangMatrixMultParam>(attrs.parsed);
- // Compute dL
+ // Compute dA
DType scale(param.alpha);
if (param.rightside == param.transpose) {
- gemm::op(dB, A, dL, scale, DType(0.), param.transpose, !param.transpose,
s);
+ gemm::op(dC, B, dA, scale, DType(0.), param.transpose, !param.transpose,
s);
} else {
- gemm::op(A, dB, dL, scale, DType(0.), !param.transpose, param.transpose,
s);
+ gemm::op(B, dC, dA, scale, DType(0.), !param.transpose, param.transpose,
s);
}
using namespace mxnet_op;
- Kernel<ZeroUpper, xpu>::Launch(s, dL.MSize(), dL.size(1)*dL.stride_,
dL.stride_,
- dL.dptr_);
- // Compute dA
- if (dA.dptr_ != dB.dptr_) Copy(dA, dB, s);
- trmm::op(L, dA, scale, param.rightside, !param.transpose, s);
+ Kernel<ZeroTriangular, xpu>::Launch(s, dA.MSize(), dA.size(1)*dA.stride_,
dA.stride_,
+ dA.dptr_, !param.lower);
+ // Compute dB
+ if (dB.dptr_ != dC.dptr_) Copy(dB, dC, s);
+ trmm::op(A, dB, scale, param.rightside, param.lower, !param.transpose, s);
}
template<typename xpu, typename DType>
- static void op(const Tensor<xpu, 3, DType>& dB, const Tensor<xpu, 3, DType>&
L,
- const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
dL,
- const Tensor<xpu, 3, DType>& dA, const OpContext& ctx,
+ static void op(const Tensor<xpu, 3, DType>& dC, const Tensor<xpu, 3, DType>&
A,
+ const Tensor<xpu, 3, DType>& B, const Tensor<xpu, 3, DType>&
dA,
+ const Tensor<xpu, 3, DType>& dB, const OpContext& ctx,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
- op(dB, L, A, dL, dA, s, attrs);
+ op(dC, A, B, dA, dB, s, attrs);
}
};
@@ -586,13 +612,13 @@ struct gelqf_backward {
Tensor<xpu, 3, DType> tempM = ctx.requested[0]
.get_space_typed<xpu, 3, DType>(dL.shape_, s);
Copy(tempM, dL, s);
- trmm::op(L, tempM, DType(1.0), false, true, s);
+ trmm::op(L, tempM, DType(1.0), false, true, true, s);
gemm::op(dA, Q, tempM, DType(-1.0), DType(1.0), false, true, s);
- Kernel<CopyLowerToUpper, xpu>::Launch
+ Kernel<CopyTriangularToOppositeSide, xpu>::Launch
(s, tempM.MSize(), tempM.size(1)*tempM.stride_, tempM.stride_,
- tempM.dptr_);
+ tempM.dptr_, false);
gemm::op(tempM, Q, dA, DType(1.0), DType(1.0), false, false, s);
- trsm::op(L, dA, DType(1.0), false, true, s);
+ trsm::op(L, dA, DType(1.0), false, true, true, s);
}
};
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index 91bcdd3..f8a130d 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -30,6 +30,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(LaMatrixMacParam);
DMLC_REGISTER_PARAMETER(LaMatrixMultParam);
+DMLC_REGISTER_PARAMETER(LaCholeskyParam);
DMLC_REGISTER_PARAMETER(LaTriangMatrixMultParam);
DMLC_REGISTER_PARAMETER(LaSyrkParam);
@@ -178,11 +179,12 @@ NNVM_REGISTER_OP(_linalg_potrf)
.describe(R"code(Performs Cholesky factorization of a symmetric
positive-definite matrix.
Input is a tensor *A* of dimension *n >= 2*.
-If *n=2*, the Cholesky factor *L* of the symmetric, positive definite matrix
*A* is
-computed. *L* is lower triangular (entries of upper triangle are all zero), has
+If *n=2*, the Cholesky factor *B* of the symmetric, positive definite matrix
*A* is
+computed. *B* is triangular (entries of upper or lower triangle are all zero),
has
positive diagonal entries, and:
- *A* = *L* \* *L*\ :sup:`T`
+ *A* = *B* \* *B*\ :sup:`T` if *lower* = *true*
+ *A* = *B*\ :sup:`T` \* *B* if *lower* = *false*
If *n>2*, *potrf* is performed separately on the trailing two dimensions for
all inputs
(batch mode).
@@ -201,6 +203,7 @@ Examples::
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
{ return std::vector<std::string>{"A"}; } )
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
@@ -214,6 +217,7 @@ Examples::
NNVM_REGISTER_OP(_backward_linalg_potrf)
.set_num_inputs(2)
.set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
{ return std::vector<std::pair<int, int> >{{0, 0}}; })
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
@@ -227,10 +231,11 @@ NNVM_REGISTER_OP(_linalg_potri)
.describe(R"code(Performs matrix inversion from a Cholesky factorization.
Input is a tensor *A* of dimension *n >= 2*.
-If *n=2*, *A* is a lower triangular matrix (entries of upper triangle are all
zero)
+If *n=2*, *A* is a triangular matrix (entries of upper or lower triangle are
all zero)
with positive diagonal. We compute:
- *out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1`
+ *out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1` if *lower* = *true*
+ *out* = *A*\ :sup:`-1` \* *A*\ :sup:`-T` if *lower* = *false*
In other words, if *A* is the Cholesky factor of a symmetric positive definite
matrix
*B* (obtained by *potrf*), then
@@ -259,6 +264,7 @@ Examples::
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
{ return std::vector<std::string>{"A"}; } )
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
@@ -272,6 +278,7 @@ Examples::
NNVM_REGISTER_OP(_backward_linalg_potri)
.set_num_inputs(3)
.set_num_outputs(1)
+.set_attr_parser(ParamParser<LaCholeskyParam>)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
{ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
@@ -283,7 +290,7 @@ NNVM_REGISTER_OP(_linalg_trmm)
Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same
shape
on the leading *n-2* dimensions.
-If *n=2*, *A* must be lower triangular. The operator performs the BLAS3
function
+If *n=2*, *A* must be triangular. The operator performs the BLAS3 function
*trmm*:
*out* = *alpha* \* *op*\ (*A*) \* *B*
@@ -346,7 +353,7 @@ NNVM_REGISTER_OP(_linalg_trsm)
Input are tensors *A*, *B*, each of dimension *n >= 2* and having the same
shape
on the leading *n-2* dimensions.
-If *n=2*, *A* must be lower triangular. The operator performs the BLAS3
function
+If *n=2*, *A* must be triangular. The operator performs the BLAS3 function
*trsm*, solving for *out* in:
*op*\ (*A*) \* *out* = *alpha* \* *B*
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index 433789c..0327dd1 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -81,10 +81,22 @@ struct LaMatrixMultParam : public
dmlc::Parameter<LaMatrixMultParam> {
}
};
+// Parameters for Cholesky factorization and matrix inversion
+struct LaCholeskyParam : public dmlc::Parameter<LaCholeskyParam> {
+ bool lower;
+ DMLC_DECLARE_PARAMETER(LaCholeskyParam) {
+ DMLC_DECLARE_FIELD(lower)
+ .set_default(true)
+ .describe
+ ("True if the triangular matrix is lower triangular, false if it is
upper triangular.");
+ }
+};
+
// Parameters for matrix-matrix multiplication where one is a triangular
matrix.
struct LaTriangMatrixMultParam : public
dmlc::Parameter<LaTriangMatrixMultParam> {
bool transpose;
bool rightside;
+ bool lower;
double alpha;
DMLC_DECLARE_PARAMETER(LaTriangMatrixMultParam) {
DMLC_DECLARE_FIELD(transpose)
@@ -93,6 +105,10 @@ struct LaTriangMatrixMultParam : public
dmlc::Parameter<LaTriangMatrixMultParam>
DMLC_DECLARE_FIELD(rightside)
.set_default(false)
.describe("Multiply triangular matrix from the right to non-triangular
one.");
+ DMLC_DECLARE_FIELD(lower)
+ .set_default(true)
+ .describe
+ ("True if the triangular matrix is lower triangular, false if it is
upper triangular.");
DMLC_DECLARE_FIELD(alpha)
.set_default(1.0)
.describe("Scalar factor to be applied to the result.");
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index 6f0c98e..7ff4228 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5229,7 +5229,7 @@ def _make_symm_symbol(a, ndims):
tr_shape = tuple(tr_shape)
return 0.5 * (a + mx.sym.transpose(a, axes=tr_shape))
-def _make_lower_triangle_symm(a, ndims, m, dtype=np.float32):
+def _make_triangle_symm(a, ndims, m, lower, dtype=np.float32):
assert ndims >= 2
# The last two dimensions must both be m
# Create mask for lower triangle and diagonal
@@ -5240,6 +5240,9 @@ def _make_lower_triangle_symm(a, ndims, m,
dtype=np.float32):
index = mx.sym.arange(start=0, stop=m-j, step=1, dtype=np.int32)
part2 = mx.sym.one_hot(index, depth=m, dtype=dtype)
lt_mask = lt_mask + mx.sym.concat(*[part1, part2], dim=0)
+ if not lower:
+ lt_mask = mx.sym.reshape(lt_mask, shape=(m, m))
+ lt_mask = mx.sym.transpose(lt_mask, axes=(1, 0))
shp = tuple([1]*(ndims-2) + [m, m])
lt_mask = mx.sym.reshape(lt_mask, shape=shp)
return mx.sym.broadcast_mul(a, lt_mask)
@@ -5381,141 +5384,147 @@ def test_laop():
check_grad(test_gemm, [a2, b2])
# Now test all the other operators.
+ for lower in [True, False]:
+ upper = not lower
+
+ # Tests with trivial 1x1 matrices.
+ shape = (4, 4, 1, 1)
+ data_in = np.random.uniform(1, 10, shape)
+ # test potrf
+ # Note: Have to symmetrize input, for gradient test to work
+ res_potrf = np.sqrt(data_in)
+ test_potrf = mx.sym.linalg.potrf(data1, lower=lower)
+ check_fw(test_potrf, [data_in], [res_potrf])
+ if grad_check == 1:
+ check_grad(test_potrf, [data_in])
+ # test potri
+ ones = mx.nd.ones(shape).asnumpy()
+ res_potri = np.divide(ones, data_in * data_in)
+ test_potri = mx.sym.linalg.potri(data1, lower=lower)
+ check_fw(test_potri, [data_in], [res_potri])
+ if grad_check == 1:
+ check_grad(test_potri, [data_in])
+ # test trsm
+ trian_in = data_in * 7.
+ test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7., lower=lower)
+ check_fw(test_trsm, [trian_in, data_in], [ones])
+ if grad_check == 1:
+ check_grad(test_trsm, [trian_in,data_in])
+ # test trmm
+ trian_in = np.divide(ones, trian_in)
+ test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True,
+ rightside=True, lower=lower)
+ check_fw(test_trmm, [trian_in, data_in], [ones])
+ if grad_check == 1:
+ check_grad(test_trmm, [trian_in, data_in])
+ # test sumlogdiag
+ res_sumlogdiag = np.reshape(np.log(data_in), (4, 4))
+ test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1)
+ check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag])
+ if grad_check == 1:
+ check_grad(test_sumlogdiag, [data_in])
+
+ # more elaborate example of Cholesky factorization
+ matrix = np.array([[9., 3., -6., 12.],
+ [3., 26., -7., -11.],
+ [-6., -7., 9., 7.],
+ [12., -11., 7., 65.]])
+ trian = np.array([[3., 0., 0., 0.],
+ [1., 5., 0., 0.],
+ [-2., -1., 2., 0.],
+ [4., -3., 6., 2.]])
+ pow = np.array([[2., 1., 1., 1.],
+ [1., 4., 1., 1.],
+ [1., 1., 8., 1.],
+ [1., 1., 1., 16.]])
+ inv = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.],
+ [0.05/3., 0.05, 0.05, 0.],
+ [2.65, 0.05, 2.5, -0.75],
+ [-2.5/3., 0., -0.75, 0.25]])
+ ident = np.eye(4)
+
+ low_trian = trian
+ if not lower:
+ trian = np.transpose(trian)
+
+ # test potrf
+ test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4),
lower=lower)
+ a = rep_3x(matrix, 4, 4)
+ r = rep_3x(trian, 4, 4)
+ check_fw(test_potrf, [a], [r])
+ if grad_check == 1:
+ check_grad(test_potrf, [a])
+
+ #test potri
+ data1_ltri = _make_triangle_symm(
+ data1, ndims=4, m=4, lower=lower, dtype=dtype)
+ test_potri = mx.sym.linalg.potri(data1_ltri, lower=lower)
+ a = rep_3x(trian, 4, 4)
+ r = rep_3x(inv, 4, 4)
+ check_fw(test_potri, [a], [r])
+ if grad_check == 1:
+ check_grad(test_potri, [a])
+
+ # test trsm
+ test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7.,
transpose=upper, lower=lower)
+ a = rep_3x(trian, 4, 4)
+ b = rep_3x(matrix, 4, 4)
+ r = rep_3x(7. * np.transpose(low_trian), 4, 4)
+ check_fw(test_trsm, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trsm, [a, b])
- # Tests with trivial 1x1 matrices.
- shape = (4, 4, 1, 1)
- data_in = np.random.uniform(1, 10, shape)
- # test potrf
- # Note: Have to symmetrize input, for gradient test to work
- res_potrf = np.sqrt(data_in)
- test_potrf = mx.sym.linalg.potrf(data1)
- check_fw(test_potrf, [data_in], [res_potrf])
- if grad_check == 1:
- check_grad(test_potrf, [data_in])
- # test potri
- ones = mx.nd.ones(shape).asnumpy()
- res_potri = np.divide(ones, data_in * data_in)
- test_potri = mx.sym.linalg.potri(data1)
- check_fw(test_potri, [data_in], [res_potri])
- if grad_check == 1:
- check_grad(test_potri, [data_in])
- # test trsm
- trian_in = data_in * 7.
- test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7.)
- check_fw(test_trsm, [trian_in, data_in], [ones])
- if grad_check == 1:
- check_grad(test_trsm, [trian_in,data_in])
- # test trmm
- trian_in = np.divide(ones, trian_in)
- test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True,
- rightside=True)
- check_fw(test_trmm, [trian_in, data_in], [ones])
- if grad_check == 1:
- check_grad(test_trmm, [trian_in, data_in])
- # test sumlogdiag
- res_sumlogdiag = np.reshape(np.log(data_in), (4, 4))
- test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1)
- check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag])
- if grad_check == 1:
- check_grad(test_sumlogdiag, [data_in])
-
- # more elaborate example of Cholesky factorization
- matrix = np.array([[9., 3., -6., 12.],
- [3., 26., -7., -11.],
- [-6., -7., 9., 7.],
- [12., -11., 7., 65.]])
- trian = np.array([[3., 0., 0., 0.],
- [1., 5., 0., 0.],
- [-2., -1., 2., 0.],
- [4., -3., 6., 2.]])
- pow = np.array([[2., 1., 1., 1.],
- [1., 4., 1., 1.],
- [1., 1., 8., 1.],
- [1., 1., 1., 16.]])
- inv = np.array([[8.95/3., 0.05/3., 2.65, -2.5/3.],
- [0.05/3., 0.05, 0.05, 0.],
- [2.65, 0.05, 2.5, -0.75],
- [-2.5/3., 0., -0.75, 0.25]])
- ident = np.eye(4)
-
- # test potrf
- test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4))
- a = rep_3x(matrix, 4, 4)
- r = rep_3x(trian, 4, 4)
- check_fw(test_potrf, [a], [r])
- if grad_check == 1:
- check_grad(test_potrf, [a])
-
- #test potri
- data1_ltri = _make_lower_triangle_symm(
- data1, ndims=4, m=4, dtype=dtype)
- test_potri = mx.sym.linalg.potri(data1_ltri)
- a = rep_3x(trian, 4, 4)
- r = rep_3x(inv, 4, 4)
- check_fw(test_potri, [a], [r])
- if grad_check == 1:
- check_grad(test_potri, [a])
-
- # test trsm
- test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7.)
- a = rep_3x(trian, 4, 4)
- b = rep_3x(matrix, 4, 4)
- r = rep_3x(7. * np.transpose(trian), 4, 4)
- check_fw(test_trsm, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trsm, [a, b])
-
- test_trsm2 = mx.sym.linalg.trsm(
- data1_ltri, data2, alpha=-2., rightside=True, transpose=True)
- r = rep_3x(-2. * trian, 4, 4)
- check_fw(test_trsm2, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trsm2, [a, b])
+ test_trsm2 = mx.sym.linalg.trsm(
+ data1_ltri, data2, alpha=-2., rightside=True, transpose=lower,
lower=lower)
+ r = rep_3x(-2. * low_trian, 4, 4)
+ check_fw(test_trsm2, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trsm2, [a, b])
- test_trsm3 = mx.sym.linalg.trsm(
- data1_ltri, data2, alpha=0.5, transpose=True)
- b = rep_3x(np.transpose(trian), 4, 4)
- r = rep_3x(0.5 * ident, 4, 4)
- check_fw(test_trsm3, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trsm3, [a, b])
+ test_trsm3 = mx.sym.linalg.trsm(
+ data1_ltri, data2, alpha=0.5, transpose=lower, lower=lower)
+ b = rep_3x(np.transpose(low_trian), 4, 4)
+ r = rep_3x(0.5 * ident, 4, 4)
+ check_fw(test_trsm3, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trsm3, [a, b])
- test_trsm4 = mx.sym.linalg.trsm(
- data1_ltri, data2, alpha=-0.5, rightside=True)
- b = rep_3x(trian, 4, 4)
- r = rep_3x(-0.5 * ident, 4, 4)
- check_fw(test_trsm4, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trsm4, [a, b])
-
- # test trmm
- test_trmm = mx.sym.linalg.trmm(
- data1_ltri, data2, alpha=7., transpose=True, rightside=True)
- a = rep_3x(trian, 4, 4)
- b = rep_3x(matrix, 4, 4)
- r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4)
- check_fw(test_trmm, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trmm, [a, b])
+ test_trsm4 = mx.sym.linalg.trsm(
+ data1_ltri, data2, alpha=-0.5, rightside=True, transpose=upper,
lower=lower)
+ b = rep_3x(low_trian, 4, 4)
+ r = rep_3x(-0.5 * ident, 4, 4)
+ check_fw(test_trsm4, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trsm4, [a, b])
+
+ # test trmm
+ test_trmm = mx.sym.linalg.trmm(
+ data1_ltri, data2, alpha=7., transpose=True, rightside=True,
lower=lower)
+ a = rep_3x(trian, 4, 4)
+ b = rep_3x(matrix, 4, 4)
+ r = rep_3x(7. * np.dot(matrix, trian.T), 4, 4)
+ check_fw(test_trmm, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trmm, [a, b])
- test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2.)
- r = rep_3x(-2. * np.dot(trian, matrix), 4, 4)
- check_fw(test_trmm2, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trmm2, [a, b])
+ test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2.,
lower=lower)
+ r = rep_3x(-2. * np.dot(trian, matrix), 4, 4)
+ check_fw(test_trmm2, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trmm2, [a, b])
- test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True)
- r = rep_3x(np.dot(matrix, trian), 4, 4)
- check_fw(test_trmm3, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trmm3, [a, b])
+ test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True,
lower=lower)
+ r = rep_3x(np.dot(matrix, trian), 4, 4)
+ check_fw(test_trmm3, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trmm3, [a, b])
- test_trmm4 = mx.sym.linalg.trmm(
- data1_ltri, data2, alpha=1.2, transpose=True)
- r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4)
- check_fw(test_trmm4, [a, b], [r])
- if grad_check == 1:
- check_grad(test_trmm4, [a, b])
+ test_trmm4 = mx.sym.linalg.trmm(
+ data1_ltri, data2, alpha=1.2, transpose=True, lower=lower)
+ r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4)
+ check_fw(test_trmm4, [a, b], [r])
+ if grad_check == 1:
+ check_grad(test_trmm4, [a, b])
# test sumlogdiag
a = rep_3x(pow, 4, 4)