This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 68181ef Operator linalg_syevd: Symmetric eigendecomposition (#7966)
68181ef is described below
commit 68181ef42d631293ea0d41e23bf54e4097d83aea
Author: mseeger <[email protected]>
AuthorDate: Wed Sep 27 20:34:42 2017 +0200
Operator linalg_syevd: Symmetric eigendecomposition (#7966)
---
docs/api/python/ndarray/linalg.md | 1 +
docs/api/python/symbol/linalg.md | 1 +
src/operator/c_lapack_api.h | 38 ++++++
src/operator/linalg.h | 21 ++++
src/operator/linalg_impl.h | 57 +++++++++
src/operator/tensor/la_op.cc | 79 +++++++++++-
src/operator/tensor/la_op.h | 93 +++++++++++++-
src/operator/tensor/la_op_inline.h | 131 +++++++++++++++++++
tests/python/unittest/test_operator.py | 224 +++++++++++++++++++++++++++------
9 files changed, 602 insertions(+), 43 deletions(-)
diff --git a/docs/api/python/ndarray/linalg.md
b/docs/api/python/ndarray/linalg.md
index 9b3ee6c..0a85b48 100644
--- a/docs/api/python/ndarray/linalg.md
+++ b/docs/api/python/ndarray/linalg.md
@@ -37,6 +37,7 @@ In the rest of this document, we list routines provided by
the `ndarray.linalg`
sumlogdiag
syrk
gelqf
+ syevd
```
## API Reference
diff --git a/docs/api/python/symbol/linalg.md b/docs/api/python/symbol/linalg.md
index 85b8b2a..d22ca8e 100644
--- a/docs/api/python/symbol/linalg.md
+++ b/docs/api/python/symbol/linalg.md
@@ -37,6 +37,7 @@ In the rest of this document, we list routines provided by
the `symbol.linalg` p
sumlogdiag
syrk
gelqf
+ syevd
```
## API Reference
diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h
index 8d774af..53b0bf2 100644
--- a/src/operator/c_lapack_api.h
+++ b/src/operator/c_lapack_api.h
@@ -109,6 +109,13 @@ extern "C" {
MXNET_LAPACK_FSIG_ORGQR(sorgqr, float)
MXNET_LAPACK_FSIG_ORGQR(dorgqr, double)
+
+ #define MXNET_LAPACK_FSIG_SYEVD(func, dtype) \
+ void func##_(char *jobz, char *uplo, int *n, dtype *a, int *lda, dtype *w,
\
+ dtype *work, int *lwork, int *iwork, int *liwork, int *info);
+
+ MXNET_LAPACK_FSIG_SYEVD(ssyevd, float)
+ MXNET_LAPACK_FSIG_SYEVD(dsyevd, double)
}
#define MXNET_LAPACK_ROW_MAJOR 101
@@ -237,6 +244,26 @@ inline void flip<cpu, double>(int m, int n,
MXNET_LAPACK_CWRAP_ORGLQ(s, float)
MXNET_LAPACK_CWRAP_ORGLQ(d, double)
+ // Note: Supports row-major format only. Internally, column-major is used,
so all
+ // inputs/outputs are flipped (in particular, uplo is flipped).
+ #define MXNET_LAPACK_CWRAP_SYEVD(func, dtype) \
+ inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype
*a, \
+ int lda, dtype *w, dtype *work, int lwork, \
+ int *iwork, int liwork) { \
+ if (matrix_layout == MXNET_LAPACK_ROW_MAJOR) { \
+ int info(0); \
+ char jobz('V'); \
+ char uplo_(loup(uplo, true)); \
+ func##_(&jobz, &uplo_, &n, a, &lda, w, work, &lwork, iwork, &liwork,
&info); \
+ return info; \
+ } else { \
+ CHECK(false) << "MXNET_LAPACK_" << #func << " implemented for row-major
layout only"; \
+ return 1; \
+ } \
+ }
+ MXNET_LAPACK_CWRAP_SYEVD(ssyevd, float)
+ MXNET_LAPACK_CWRAP_SYEVD(dsyevd, double)
+
#else
// use pragma message instead of warning
@@ -258,6 +285,14 @@ inline void flip<cpu, double>(int m, int n,
return 1; \
}
+ #define MXNET_LAPACK_CWRAPPER3(func, dtype) \
+ inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype
*a, \
+ int lda, dtype *w, dtype *work, int lwork, \
+ int *iwork, int liwork) { \
+ LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not
available."; \
+ return 1; \
+ }
+
#define MXNET_LAPACK_UNAVAILABLE(func) \
inline int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not
available."; \
@@ -277,6 +312,9 @@ inline void flip<cpu, double>(int m, int n,
MXNET_LAPACK_CWRAPPER2(sorglq, float)
MXNET_LAPACK_CWRAPPER2(dorglq, double)
+ MXNET_LAPACK_CWRAPPER3(ssyevd, float)
+ MXNET_LAPACK_CWRAPPER3(dsyevd, double)
+
#endif
template <typename DType>
diff --git a/src/operator/linalg.h b/src/operator/linalg.h
index a404fda..651b8e2 100644
--- a/src/operator/linalg.h
+++ b/src/operator/linalg.h
@@ -163,6 +163,27 @@ template<typename xpu, typename DType>
int linalg_gelqf_workspace_query(const Tensor<xpu, 2, DType>& A,
Stream<xpu> *s = 0);
+//////////////////////////////// SYEVD
////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "syevd". Please refer to the
+// LAPACK documentation for further details.
+// Note:
+// - The current implementation works for CPU only
+// - A is input and output parameter (overwritten by U)
+// - Input A is symmetric, we access the lower triangle only
+// - Requires two workspace arrays, one in DType, other in int.
+
+template<typename xpu, typename DType>
+void linalg_syevd(const Tensor<xpu, 2, DType>& A,
+ const Tensor<xpu, 1, DType>& L,
+ const Tensor<xpu, 1, DType>& work,
+ const Tensor<xpu, 1, int>& iwork, Stream<xpu> *s = 0);
+
+// This function determines the amount of workspace needed for linalg_syevd
+template<typename xpu, typename DType>
+void linalg_syevd_workspace_query(const Tensor<xpu, 2, DType>& A, int* lwork,
+ int* liwork, Stream<xpu> *s = 0);
+
#include "linalg_impl.h"
#endif // MXNET_OPERATOR_LINALG_H_
diff --git a/src/operator/linalg_impl.h b/src/operator/linalg_impl.h
index 647c618..f7293c6 100644
--- a/src/operator/linalg_impl.h
+++ b/src/operator/linalg_impl.h
@@ -966,4 +966,61 @@ LINALG_GPU_GELQF_WORKSPACE_QUERY(D, double)
#endif // __CUDACC__
+//////////////////////////////// SYEVD
////////////////////////////////////////////
+
+// CPU/GPU-versions of LAPACK function "syevd"
+
+template<typename xpu, typename DType> inline
+void check_syevd(const Tensor<xpu, 2, DType>& A,
+ const Tensor<xpu, 1, DType>& L) {
+ // Any checking that helps user debug potential problems.
+ CHECK_EQ(A.size(0), A.size(1))
+ << "A must be square symmetric matrix";
+ CHECK_EQ(A.size(0), L.size(0))
+ << "A, L have incompatible sizes";
+}
+
+#define LINALG_CPU_SYEVD(fname, DType) \
+template<> inline \
+void linalg_syevd<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+ const Tensor<cpu, 1, DType>& L, \
+ const Tensor<cpu, 1, DType>& work, \
+ const Tensor<cpu, 1, int>& iwork, \
+ Stream<cpu> *s) { \
+ check_syevd(A, L); \
+ int ret(MXNET_LAPACK_##fname(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
+ A.dptr_, A.stride_, L.dptr_, work.dptr_, \
+ work.size(0), iwork.dptr_, iwork.size(0))); \
+ CHECK_EQ(ret, 0) << #fname << " failed in lapack on cpu."; \
+}
+// LINALG_CPU_SYEVD(ssyevd, float)
+LINALG_CPU_SYEVD(dsyevd, double)
+
+template<> inline
+void linalg_syevd<cpu, float>(const Tensor<cpu, 2, float>& A,
+ const Tensor<cpu, 1, float>& L,
+ const Tensor<cpu, 1, float>& work,
+ const Tensor<cpu, 1, int>& iwork,
+ Stream<cpu> *s) {
+ CHECK(false) << "linalg_syevd is not currently implemented for float32." <<
std::endl
+ << "Please use float64 for now. If the rest of your code runs
on float32,"
+ << " please use the Cast operator.";
+}
+
+#define LINALG_CPU_SYEVD_WORKSPACE_QUERY(func, DType) \
+template<> inline \
+void linalg_syevd_workspace_query<cpu, DType>(const Tensor<cpu, 2, DType>& A, \
+ int* lwork, int* liwork, \
+ Stream<cpu> *s) { \
+ DType work(0.0); \
+ int iwork(0); \
+ int ret(MXNET_LAPACK_##func(MXNET_LAPACK_ROW_MAJOR, 'L', A.size(0), \
+ A.dptr_, A.stride_, &work, &work, -1, &iwork, \
+ -1)); \
+ *lwork = static_cast<int>(work); \
+ *liwork = iwork; \
+}
+LINALG_CPU_SYEVD_WORKSPACE_QUERY(ssyevd, float)
+LINALG_CPU_SYEVD_WORKSPACE_QUERY(dsyevd, double)
+
#endif // MXNET_OPERATOR_LINALG_IMPL_H_
diff --git a/src/operator/tensor/la_op.cc b/src/operator/tensor/la_op.cc
index c332e45..eedbc62 100644
--- a/src/operator/tensor/la_op.cc
+++ b/src/operator/tensor/la_op.cc
@@ -209,7 +209,7 @@ with positive diagonal. We compute:
*out* = *A*\ :sup:`-T` \* *A*\ :sup:`-1`
In other words, if *A* is the Cholesky factor of a symmetric positive definite
matrix
-*B*, then
+*B* (obtained by *potrf*), then
*out* = *B*\ :sup:`-1`
@@ -219,8 +219,8 @@ If *n>2*, *potri* is performed separately on the trailing
two dimensions for all
.. note:: The operator supports float32 and float64 data types only.
.. note:: Use this operator only if you are certain you need the inverse of
*B*, and
- cannot use the Cholesky factor alone. The latter is more numerically
- stable and cheaper.
+ cannot use the Cholesky factor *A* (*potrf*), together with
backsubstitution
+ (*trsm*). The latter is numerically much safer, and also cheaper.
Examples::
@@ -550,5 +550,78 @@ NNVM_REGISTER_OP(_backward_linalg_gelqf)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 4, 1,
gelqf_backward>);
+NNVM_REGISTER_OP(_linalg_syevd)
+.describe(R"code(Eigendecomposition for symmetric matrix.
+Input is a tensor *A* of dimension *n >= 2*.
+
+If *n=2*, *A* must be symmetric, of shape *(x, x)*. We compute the
eigendecomposition,
+resulting in the orthonormal matrix *U* of eigenvectors, shape *(x, x)*, and
the
+vector *L* of eigenvalues, shape *(x,)*, so that:
+
+ *U* \* *A* = *diag(L)* \* *U*
+
+Here:
+
+ *U* \* *U*\ :sup:`T` = *U*\ :sup:`T` \* *U* = *I*
+
+where *I* is the identity matrix. Also, *L(0) <= L(1) <= L(2) <= ...*
(ascending order).
+
+If *n>2*, *syevd* is performed separately on the trailing two dimensions of
*A* (batch
+mode). In this case, *U* has *n* dimensions like *A*, and *L* has *n-1*
dimensions.
+
+.. note:: The operator supports float32 and float64 data types only.
+
+.. note:: For the time being, this operator supports the float64 data type
only. If the
+ rest of your expression uses float32, please apply the Cast operator
to inputs
+ and outputs.
+
+.. note:: Derivatives for this operator are defined only if *A* is such that
all its
+ eigenvalues are distinct, and the eigengaps are not too small. If
you need
+ gradients, do not apply this operator to matrices with multiple
eigenvalues.
+
+Examples::
+
+ // Single symmetric eigendecomposition
+ A = [[1., 2.], [2., 4.]]
+ U, L = syevd(A)
+ U = [[0.89442719, -0.4472136],
+ [0.4472136, 0.89442719]]
+ L = [0., 5.]
+
+ // Batch symmetric eigendecomposition
+ A = [[[1., 2.], [2., 4.]],
+ [[1., 2.], [2., 5.]]]
+ U, L = syevd(A)
+ U = [[[0.89442719, -0.4472136],
+ [0.4472136, 0.89442719]],
+ [[0.92387953, -0.38268343],
+ [0.38268343, 0.92387953]]]
+ L = [[0., 5.],
+ [0.17157288, 5.82842712]]
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(2)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs)
+ { return std::vector<std::string>{"A"}; } )
+.set_attr<nnvm::FInferShape>("FInferShape", LaEigFactShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 2>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+ { return std::vector<std::pair<int, int>>{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+ { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<FCompute>("FCompute<cpu>", LaOpForwSyevd<cpu, syevd>)
+.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseOut{"_backward_linalg_syevd"})
+.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be
factorized");
+
+NNVM_REGISTER_OP(_backward_linalg_syevd)
+.set_num_inputs(4)
+.set_num_outputs(1)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs)
+ { return std::vector<std::pair<int, int> >{{0, 0}}; })
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs)
+ { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; })
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", LaOpBackwSyevd<cpu, syevd_backward>);
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h
index b4093f6..a323139 100644
--- a/src/operator/tensor/la_op.h
+++ b/src/operator/tensor/la_op.h
@@ -298,7 +298,49 @@ inline bool LaLQFactShape(const nnvm::NodeAttrs& attrs,
return false;
}
+// Shape inference function for linalg_syevd
+// Inputs: A. Outputs: U, L
+inline bool LaEigFactShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape>* in_attrs,
+ std::vector<TShape>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1);
+ CHECK_EQ(out_attrs->size(), 2);
+ const TShape& in_a = (*in_attrs)[0];
+ const TShape& out_u = (*out_attrs)[0];
+ const TShape& out_l = (*out_attrs)[1];
+ if ( in_a.ndim() >= 2 ) {
+ // Forward shape inference.
+ const int ndim(in_a.ndim());
+ CHECK_EQ(in_a[ndim-2], in_a[ndim-1])
+ << "Input A shape wrong: Last two dimensions must be equal";
+ // U must have same shape as A
+ SHAPE_ASSIGN_CHECK(*out_attrs, 0, in_a);
+ std::vector<int> oshape_l(ndim-1);
+ for ( int i = 0; i < ndim-1; ++i ) {
+ oshape_l[i] = in_a[i];
+ }
+ TShape tshape_l(oshape_l.begin(), oshape_l.end());
+ SHAPE_ASSIGN_CHECK(*out_attrs, 1, tshape_l);
+ return true;
+ }
+ if ( out_u.ndim() >= 2 && out_u.ndim() == out_l.ndim()+1 ) {
+ // Backward shape inference.
+ const int ndim(out_u.ndim());
+ for ( int i = 0; i < ndim-1; ++i ) {
+ CHECK_EQ(out_u[i], out_l[i])
+ << "Outputs U, L must have same dimensions except for last";
+ }
+ CHECK_EQ(out_u[ndim-2], out_u[ndim-1])
+ << "Output U shape wrong: Last two dimensions must be equal";
+ // A must have same shape as U
+ SHAPE_ASSIGN_CHECK(*in_attrs, 0, out_u);
+ return true;
+ }
+ return false;
+}
+
// Adapters for calling the various operators with appropriate signatures.
+
template<typename xpu, typename DType, int idim, int odim, int inum, int onum,
typename laop>
struct LaOpCaller {
static void op(const std::vector<TBlob>& inputs,
@@ -432,7 +474,6 @@ void LaOpForward(const nnvm::NodeAttrs& attrs,
});
}
-
template<typename xpu, int idim, int odim, int inum, int onum, typename laop>
void LaOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -462,6 +503,56 @@ void LaOpBackward(const nnvm::NodeAttrs& attrs,
});
}
+// Specific wrapper for syevd (cannot use the default ones, because A, U have
+// different dimensionality than L
+
+// (A) => (U, L)
+template<typename xpu, typename laop>
+void LaOpForwSyevd(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ CHECK_EQ(inputs.size(), 1);
+ CHECK_EQ(outputs.size(), 2);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+ laop::op(inputs[0].FlatToKD<xpu, 3, OType>(s),
+ outputs[0].FlatToKD<xpu, 3, OType>(s),
+ outputs[1].FlatToKD<xpu, 2, OType>(s), ctx, attrs);
+ });
+}
+
+// (dU, dL, U, L) => (dA)
+template<typename xpu, typename laop>
+void LaOpBackwSyevd(const nnvm::NodeAttrs& attrs,
+ const OpContext& ctx,
+ const std::vector<TBlob>& inputs,
+ const std::vector<OpReqType>& req,
+ const std::vector<TBlob>& outputs) {
+ using namespace mshadow;
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ CHECK_EQ(inputs.size(), 4);
+ CHECK_EQ(outputs.size(), 1);
+ MSHADOW_SGL_DBL_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+ std::vector<TBlob> tspace(outputs);
+ if ( req[0] == kAddTo ) {
+ tspace[0].dptr_ = ctx.requested[0]
+ .get_space_typed<xpu, 1, OType>(Shape1(outputs[0].Size()), s).dptr_;
+ }
+ laop::op(inputs[0].FlatToKD<xpu, 3, OType>(s),
+ inputs[1].FlatToKD<xpu, 2, OType>(s),
+ inputs[2].FlatToKD<xpu, 3, OType>(s),
+ inputs[3].FlatToKD<xpu, 2, OType>(s),
+ tspace[0].FlatToKD<xpu, 3, OType>(s), ctx, attrs);
+ if ( req[0] == kAddTo ) {
+ Tensor<xpu, 1, OType> out = outputs[0].FlatTo1D<xpu, OType>(s);
+ out += tspace[0].FlatTo1D<xpu, OType>(s);
+ }
+ });
+}
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/tensor/la_op_inline.h
b/src/operator/tensor/la_op_inline.h
index 2d338ee..dce653c 100644
--- a/src/operator/tensor/la_op_inline.h
+++ b/src/operator/tensor/la_op_inline.h
@@ -284,6 +284,64 @@ struct gelqf {
}
};
+// If (U, L) = syevd(A) [symmetric eigendecomposition], this helper acts on
each row
+// of U, deciding whether its sign is flipped or not.
+// If u denotes a row, we choose the sign s.t. u_k > 0, where k = argmax|u_j|.
In case
+// of a tie, the smaller index k decides.
+struct SyevdEigenVecSigns {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int i, int n, DType* U, int ldu) {
+ DType* urow(U + (i*ldu));
+ DType maxval(fabs(urow[0])), uval(0.0);
+ int maxind(0);
+ for (int i = 1; i < n; ++i) {
+ uval = fabs(urow[i]);
+ if (uval > maxval) {
+ maxval = uval;
+ maxind = i;
+ }
+ }
+ if (urow[maxind] < 0.0) {
+ // Flip all signs
+ for (int i = 0; i < n; ++i) {
+ urow[i] = -urow[i];
+ }
+ }
+ }
+};
+
+// (U, L) = syevd(A) [symmetric eigendecomposition]
+// - Input A must be symmetric, only lower triangle is used
+// - U can overwrite A
+// - Needs workspace (both DType and int), size of which is determined by a
+// workspace query
+struct syevd {
+ template<typename xpu, typename DType>
+ static void op(const Tensor<xpu, 3, DType>& A, const Tensor<xpu, 3, DType>&
U,
+ const Tensor<xpu, 2, DType>& L, const OpContext& ctx,
+ const nnvm::NodeAttrs& attrs) {
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ linalg_check_batch_size(A.size(0), U.size(0), L.size(0));
+ if (A.dptr_ != U.dptr_) Copy(U, A, s);
+ // From here on, we work on U only
+ // Reserve workspaces (size determined by query)
+ int lwork(0), liwork(0);
+ linalg_syevd_workspace_query(U[0], &lwork, &liwork, s);
+ Tensor<xpu, 1, DType> work = ctx.requested[0]
+ .get_space_typed<xpu, 1, DType>(Shape1(lwork), s);
+ Tensor<xpu, 1, int> iwork = ctx.requested[0]
+ .get_space_typed<xpu, 1, int>(Shape1(liwork), s);
+ // Loop over items in batch
+ for (index_t i = 0; i < U.size(0); ++i) {
+ linalg_syevd(U[i], L[i], work, iwork, s);
+ }
+ // Set signs of eigenvectors in a deterministic way
+ using namespace mxnet_op;
+ Kernel<SyevdEigenVecSigns, xpu>::Launch
+ (s, U.size(0)*U.size(1), U.size(1), U.dptr_, U.stride_);
+ }
+};
+
// Backward operators (always using batch processing)
struct gemm_backward {
@@ -540,6 +598,79 @@ struct gelqf_backward {
}
};
+// Helper for syevd_backward. See technical report for details
+// Note: Could be parallelized more, but this is subdominant anyway
+template<typename DType>
+DType syevd_back_helper_eps(DType* X);
+
+template<> inline
+float syevd_back_helper_eps(float* X) {
+ return 1e-30;
+}
+
+template<> inline
+double syevd_back_helper_eps(double* X) {
+ return 1e-100;
+}
+
+struct SyevdBackHelper {
+ template<typename DType>
+ MSHADOW_XINLINE static void Map(int k, int n, DType* X, int ldx, DType* L,
+ int ldl, DType* dL, int lddl, DType* Y,
+ int ldy) {
+ const int offx(k*n*ldx);
+ const int offy(k*n*ldy);
+ const int offl(k*ldl);
+ const int offdl(k*lddl);
+ DType denom(0.0), elem(0.0);
+ const DType eps(syevd_back_helper_eps(X));
+ // Lower and upper triangle: Loop i > j
+ for (int i = 1; i < n; ++i) {
+ for (int j = 0; j < i; ++j) {
+ denom = L[offl+i] - L[offl+j]; // Must be >=0
+ if (denom < eps) denom = eps;
+ denom *= 2.0;
+ elem = (X[offx+i*ldx+j] - X[offx+j*ldx+i])/denom;
+ Y[offy+i*ldy+j] = Y[offy+j*ldy+i] = elem;
+ }
+ }
+ // Diagonal
+ for (int i = 0; i < n; ++i) {
+ Y[offy+i*(ldy+1)] = dL[offdl+i];
+ }
+ }
+};
+
+// Have to reserve temporary storage tempM, same shape as dA.
+// dA may overwrite dU
+struct syevd_backward {
+ template<typename xpu, typename DType>
+ static void op(const Tensor<xpu, 3, DType>& dU,
+ const Tensor<xpu, 2, DType>& dL,
+ const Tensor<xpu, 3, DType>& U,
+ const Tensor<xpu, 2, DType>& L,
+ const Tensor<xpu, 3, DType>& dA,
+ const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
+ // Backward of (U, L) = syevd(A):
+ // dA = U**T * SyevdBackHelper(dU * U**T, L, dL) * U
+ using namespace mxnet_op;
+ Stream<xpu> *s = ctx.get_stream<xpu>();
+ // Need temporal space, same shape as dA
+ Tensor<xpu, 3, DType> tempM = ctx.requested[0]
+ .get_space_typed<xpu, 3, DType>(dA.shape_, s);
+ // This copy is just to make sure there are no invalid values (NaN,
infinity) in
+ // tempM. gemm multiplies tempM with 0, instead of setting entries to 0.
+ Copy(tempM, dU, s);
+ gemm::op(dU, U, tempM, DType(1.0), DType(0.0), false, true, s);
+ // SyevdBackHelper: tempM => dA
+ Kernel<SyevdBackHelper, xpu>::Launch
+ (s, dA.size(0), dA.size(1), tempM.dptr_, tempM.stride_, L.dptr_,
+ L.stride_, dL.dptr_, dL.stride_, dA.dptr_, dA.stride_);
+ gemm::op(U, dA, tempM, DType(1.0), DType(0.0), true, false, s);
+ gemm::op(tempM, U, dA, DType(1.0), DType(0.0), false, false, s);
+ }
+};
+
} // namespace op
} // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py
b/tests/python/unittest/test_operator.py
index fa73584..265a8c8 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3725,24 +3725,24 @@ def test_laop():
data_in1_t = np.transpose(data_in1)
data_in2_t = np.transpose(data_in2)
res_gemm = 4. * np.dot(data_in1, data_in2) + 7. * data_in4
- test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.)
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7.)
check_fw(test_gemm, [data_in1, data_in2, data_in4], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in2, data_in4])
res_gemm = 4. * np.dot(data_in1_t, data_in2_t) + 7. * data_in3
- test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.,
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7.,
transpose_a=True, transpose_b=True)
check_fw(test_gemm, [data_in1, data_in2, data_in3], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in2, data_in3])
res_gemm = 4. * np.dot(data_in1_t, data_in1) + 7. * data_in3
- test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.,
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7.,
transpose_a=True)
check_fw(test_gemm, [data_in1, data_in1, data_in3], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in1, data_in3])
res_gemm = 4. * np.dot(data_in1, data_in1_t) + 7. * data_in4
- test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.,
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7.,
transpose_b=True)
check_fw(test_gemm, [data_in1, data_in1, data_in4], [res_gemm])
if grad_check == 1:
@@ -3754,30 +3754,30 @@ def test_laop():
c = rep_3x(data_in4, 2, 2)
r = 4. * np.dot(data_in1, data_in2) + 7. * data_in4
r = rep_3x(r, 2, 2)
- test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha=4., beta=7.)
+ test_gemm = mx.sym.linalg.gemm(data1, data2, data3, alpha=4., beta=7.)
check_fw(test_gemm, [a, b, c], [r])
if grad_check == 1:
check_grad(test_gemm, [a, b, c])
# Check gemm2 operator same way as gemm.
res_gemm = 4. * np.dot(data_in1, data_in2)
- test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4.)
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.)
check_fw(test_gemm, [data_in1, data_in2], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in2])
res_gemm = 4. * np.dot(data_in1_t, data_in2_t)
- test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_a=True,
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., transpose_a=True,
transpose_b=True)
check_fw(test_gemm, [data_in1, data_in2], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in2])
res_gemm = 4. * np.dot(data_in1_t, data_in1)
- test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_a=True)
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., transpose_a=True)
check_fw(test_gemm, [data_in1, data_in1], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in1])
res_gemm = 4. * np.dot(data_in1, data_in1_t)
- test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4., transpose_b=True)
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4., transpose_b=True)
check_fw(test_gemm, [data_in1, data_in1], [res_gemm])
if grad_check == 1:
check_grad(test_gemm, [data_in1, data_in1])
@@ -3786,7 +3786,7 @@ def test_laop():
a = rep_3x(data_in1, 2, 3)
b = rep_3x(data_in2, 3, 2)
r = rep_3x(4. * np.dot(data_in1, data_in2), 2, 2)
- test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha=4.)
+ test_gemm = mx.sym.linalg.gemm2(data1, data2, alpha=4.)
check_fw(test_gemm, [a, b], [r])
if grad_check == 1:
check_grad(test_gemm, [a, b])
@@ -3799,33 +3799,33 @@ def test_laop():
# test potrf
# Note: Have to symmetrize input, for gradient test to work
res_potrf = np.sqrt(data_in)
- test_potrf = mx.sym.linalg_potrf(data1)
+ test_potrf = mx.sym.linalg.potrf(data1)
check_fw(test_potrf, [data_in], [res_potrf])
if grad_check == 1:
check_grad(test_potrf, [data_in])
# test potri
ones = mx.nd.ones(shape).asnumpy()
res_potri = np.divide(ones, data_in * data_in)
- test_potri = mx.sym.linalg_potri(data1)
+ test_potri = mx.sym.linalg.potri(data1)
check_fw(test_potri, [data_in], [res_potri])
if grad_check == 1:
check_grad(test_potri, [data_in])
# test trsm
trian_in = data_in * 7.
- test_trsm = mx.sym.linalg_trsm(data1, data2, alpha=7.)
+ test_trsm = mx.sym.linalg.trsm(data1, data2, alpha=7.)
check_fw(test_trsm, [trian_in, data_in], [ones])
if grad_check == 1:
check_grad(test_trsm, [trian_in,data_in])
# test trmm
trian_in = np.divide(ones, trian_in)
- test_trmm = mx.sym.linalg_trmm(data1, data2, alpha=7., transpose=True,
+ test_trmm = mx.sym.linalg.trmm(data1, data2, alpha=7., transpose=True,
rightside=True)
check_fw(test_trmm, [trian_in, data_in], [ones])
if grad_check == 1:
check_grad(test_trmm, [trian_in, data_in])
# test sumlogdiag
res_sumlogdiag = np.reshape(np.log(data_in), (4, 4))
- test_sumlogdiag = mx.sym.linalg_sumlogdiag(data1)
+ test_sumlogdiag = mx.sym.linalg.sumlogdiag(data1)
check_fw(test_sumlogdiag, [data_in], [res_sumlogdiag])
if grad_check == 1:
check_grad(test_sumlogdiag, [data_in])
@@ -3850,7 +3850,7 @@ def test_laop():
ident = np.eye(4)
# test potrf
- test_potrf = mx.sym.linalg_potrf(_make_symm_symbol(data1, ndims=4))
+ test_potrf = mx.sym.linalg.potrf(_make_symm_symbol(data1, ndims=4))
a = rep_3x(matrix, 4, 4)
r = rep_3x(trian, 4, 4)
check_fw(test_potrf, [a], [r])
@@ -3860,7 +3860,7 @@ def test_laop():
#test potri
data1_ltri = _make_lower_triangle_symm(
data1, ndims=4, m=4, dtype=dtype)
- test_potri = mx.sym.linalg_potri(data1_ltri)
+ test_potri = mx.sym.linalg.potri(data1_ltri)
a = rep_3x(trian, 4, 4)
r = rep_3x(inv, 4, 4)
check_fw(test_potri, [a], [r])
@@ -3868,7 +3868,7 @@ def test_laop():
check_grad(test_potri, [a])
# test trsm
- test_trsm = mx.sym.linalg_trsm(data1_ltri, data2, alpha=7.)
+ test_trsm = mx.sym.linalg.trsm(data1_ltri, data2, alpha=7.)
a = rep_3x(trian, 4, 4)
b = rep_3x(matrix, 4, 4)
r = rep_3x(7. * np.transpose(trian), 4, 4)
@@ -3876,14 +3876,14 @@ def test_laop():
if grad_check == 1:
check_grad(test_trsm, [a, b])
- test_trsm2 = mx.sym.linalg_trsm(
+ test_trsm2 = mx.sym.linalg.trsm(
data1_ltri, data2, alpha=-2., rightside=True, transpose=True)
r = rep_3x(-2. * trian, 4, 4)
check_fw(test_trsm2, [a, b], [r])
if grad_check == 1:
check_grad(test_trsm2, [a, b])
- test_trsm3 = mx.sym.linalg_trsm(
+ test_trsm3 = mx.sym.linalg.trsm(
data1_ltri, data2, alpha=0.5, transpose=True)
b = rep_3x(np.transpose(trian), 4, 4)
r = rep_3x(0.5 * ident, 4, 4)
@@ -3891,7 +3891,7 @@ def test_laop():
if grad_check == 1:
check_grad(test_trsm3, [a, b])
- test_trsm4 = mx.sym.linalg_trsm(
+ test_trsm4 = mx.sym.linalg.trsm(
data1_ltri, data2, alpha=-0.5, rightside=True)
b = rep_3x(trian, 4, 4)
r = rep_3x(-0.5 * ident, 4, 4)
@@ -3900,7 +3900,7 @@ def test_laop():
check_grad(test_trsm4, [a, b])
# test trmm
- test_trmm = mx.sym.linalg_trmm(
+ test_trmm = mx.sym.linalg.trmm(
data1_ltri, data2, alpha=7., transpose=True, rightside=True)
a = rep_3x(trian, 4, 4)
b = rep_3x(matrix, 4, 4)
@@ -3909,19 +3909,19 @@ def test_laop():
if grad_check == 1:
check_grad(test_trmm, [a, b])
- test_trmm2 = mx.sym.linalg_trmm(data1_ltri, data2, alpha=-2.)
+ test_trmm2 = mx.sym.linalg.trmm(data1_ltri, data2, alpha=-2.)
r = rep_3x(-2. * np.dot(trian, matrix), 4, 4)
check_fw(test_trmm2, [a, b], [r])
if grad_check == 1:
check_grad(test_trmm2, [a, b])
- test_trmm3 = mx.sym.linalg_trmm(data1_ltri, data2, rightside=True)
+ test_trmm3 = mx.sym.linalg.trmm(data1_ltri, data2, rightside=True)
r = rep_3x(np.dot(matrix, trian), 4, 4)
check_fw(test_trmm3, [a, b], [r])
if grad_check == 1:
check_grad(test_trmm3, [a, b])
- test_trmm4 = mx.sym.linalg_trmm(
+ test_trmm4 = mx.sym.linalg.trmm(
data1_ltri, data2, alpha=1.2, transpose=True)
r = rep_3x(1.2 * np.dot(trian.T, matrix), 4, 4)
check_fw(test_trmm4, [a, b], [r])
@@ -3936,27 +3936,35 @@ def test_laop():
check_grad(test_sumlogdiag, [a])
-# Tests for new operators linalg_syrk, linalg_gelqf
+# Tests for operators linalg.syrk, linalg.gelqf
def _gelqf_combined_symbol(a):
- q, l = mx.sym.linalg_gelqf(a)
- q_qt = mx.sym.linalg_syrk(q, transpose=False, alpha=1., name='Q_times_Qt')
- l_q = mx.sym.linalg_trmm(l, q, alpha=1., name='L_times_Q')
+ q, l = mx.sym.linalg.gelqf(a)
+ q_qt = mx.sym.linalg.syrk(q, transpose=False, alpha=1., name='Q_times_Qt')
+ l_q = mx.sym.linalg.trmm(l, q, alpha=1., name='L_times_Q')
return mx.sym.Group([q_qt, l_q])
# NOTE: If we leave the unused output dangling, things break if
dtype=np.float64. Namely, the
# backward gradient for the unused output is of dtype np.float32 then.
# ==> Very annoying!
def _gelqf_first_output(a):
- q, l = mx.sym.linalg_gelqf(a)
+ q, l = mx.sym.linalg.gelqf(a)
bogus_scal = mx.sym.sum(mx.sym.BlockGrad(l), axis=(), keepdims=True) * 0.0
return mx.sym.broadcast_add(q, bogus_scal)
def _gelqf_second_output(a):
- q, l = mx.sym.linalg_gelqf(a)
+ q, l = mx.sym.linalg.gelqf(a)
bogus_scal = mx.sym.sum(mx.sym.BlockGrad(q), axis=(), keepdims=True) * 0.0
return mx.sym.broadcast_add(l, bogus_scal)
+def _syevd_combined_symbol(a):
+ u, lam = mx.sym.linalg.syevd(a)
+ u_ut = mx.sym.linalg.syrk(u, transpose=False, alpha=1., name='U_times_Ut')
+ lam_u = mx.sym.broadcast_mul(mx.sym.reshape(lam, shape=(-2, 1)), u)
+ ut_lam_u = mx.sym.linalg.gemm2(u, lam_u, alpha=1., transpose_a=True,
+ transpose_b=False, name='Ut_L_U')
+ return mx.sym.Group([u_ut, ut_lam_u])
+
def test_laop_2():
np.random.seed(1896893923)
dtype = np.float64
@@ -3979,18 +3987,18 @@ def test_laop_2():
rep_3x = lambda a, m, n :\
np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))
- # Tests for linalg_syrk
+ # Tests for linalg.syrk
mnalpha_lst = [(2, 3, 1.), (5, 3, -2.), (1, 6, 5.), (3, 3, 0.5), (4, 1,
10.), (1, 1, 1.)]
for m, n, alpha in mnalpha_lst:
- #print('m={}, n={}, alpha={}'.format(m, n, alpha))
+ #print('syrk: m={}, n={}, alpha={}'.format(m, n, alpha))
data_in1 = np.random.uniform(1, 10, (m, n))
res_syrk1 = alpha * np.dot(data_in1, data_in1.T)
- test_syrk1 = mx.sym.linalg_syrk(data1, transpose=False, alpha=alpha)
+ test_syrk1 = mx.sym.linalg.syrk(data1, transpose=False, alpha=alpha)
check_fw(test_syrk1, [data_in1], [res_syrk1])
if grad_check == 1:
check_grad(test_syrk1, [data_in1])
res_syrk2 = alpha * np.dot(data_in1.T, data_in1)
- test_syrk2 = mx.sym.linalg_syrk(data1, transpose=True, alpha=alpha)
+ test_syrk2 = mx.sym.linalg.syrk(data1, transpose=True, alpha=alpha)
check_fw(test_syrk2, [data_in1], [res_syrk2])
if grad_check == 1:
check_grad(test_syrk2, [data_in1])
@@ -4005,18 +4013,18 @@ def test_laop_2():
if grad_check == 1:
check_grad(test_syrk2, [a_batch])
- # Tests for linalg_gelqf
+ # Tests for linalg.gelqf
# Currently disabled on GPU as they need cuda8
# and MxNet builds use cuda 7.5
- if default_context() != mx.cpu():
+ if not (default_context() == mx.cpu()):
return
-
+
test_gelqf2 = _gelqf_combined_symbol(data1) # Outputs (dot(Q, Q.T),
dot(L, Q))
test_gelqf_q = _gelqf_first_output(data1) # Output Q (L is not dangling)
test_gelqf_l = _gelqf_second_output(data1) # Output L (Q is not dangling)
mn_lst = [(4, 4), (1, 1), (5, 20), (1, 10), (15, 50)]
for m, n in mn_lst:
- #print('m={}, n={}'.format(m, n))
+ #print('gelqf: m={}, n={}'.format(m, n))
data_in1 = np.random.normal(0., 10., (m, n))
res_eye = np.eye(m)
res_a = data_in1
@@ -4038,6 +4046,144 @@ def test_laop_2():
check_grad(test_gelqf_l, [a_batch])
+# Tests for operator linalg.syevd
+
def _syevd_first_output(a):
    """Return a symbol exposing only U from linalg.syevd.

    The eigenvalue output lam is folded in as a gradient-blocked zero scalar
    so that it is not a dangling output, while only U carries gradient.
    """
    u, lam = mx.sym.linalg.syevd(a)
    zero = 0.0 * mx.sym.sum(mx.sym.BlockGrad(lam), axis=(), keepdims=True)
    return mx.sym.broadcast_add(u, zero)
+
def _syevd_second_output(a):
    """Return a symbol exposing only lam from linalg.syevd.

    The eigenvector output U is folded in as a gradient-blocked zero scalar
    so that it is not a dangling output, while only lam carries gradient.
    """
    u, lam = mx.sym.linalg.syevd(a)
    zero = 0.0 * mx.sym.sum(mx.sym.BlockGrad(u), axis=(), keepdims=True)
    return mx.sym.broadcast_add(lam, zero)
+
+def _syevd_forward(a):
+ lam, ut = np.linalg.eig(a)
+ ind = np.argsort(lam)
+ lam = lam[ind]
+ u = ut[:, ind].T
+ for i in range(0, a.shape[0]):
+ _syevd_forw_eigvec_sign(u[i])
+ return u, lam
+
+def _syevd_forw_eigvec_sign(v):
+ ind = np.argmax(np.abs(v))
+ if v[ind] < 0.:
+ v[:] = -v
+
+def _syevd_backward(grad_u, grad_l, u, l):
+ n = l.size
+ assert grad_l.size == n
+ assert grad_u.shape == (n, n)
+ assert u.shape == (n, n)
+ temp = np.dot(grad_u, u.T)
+ temp2 = np.diag(grad_l)
+ for i in range(1, n):
+ for j in range(0, i):
+ denom = 2. * (l[i] - l[j])
+ elem = (temp[i, j] - temp[j, i])/denom
+ temp2[i, j] = elem
+ temp2[j, i] = elem
+ temp3 = np.dot(u.T, temp2)
+ return np.dot(temp3, u)
+
def test_laop_3():
    """Forward and backward tests for the linalg.syevd operator (float64)."""
    # syevd is implemented on CPU only for now.
    if not (default_context() == mx.cpu()):
        return
    np.random.seed(1896893923)
    dtype = np.float64
    fw_rtol = 1e-6
    fw_atol = 1e-6
    eps = 1e-4
    bw_rtol = 1e-2
    bw_atol = 1e-2
    # 1 => numerical checking of gradients is enabled
    grad_check = 1

    data1 = mx.symbol.Variable('data1')

    def check_fw(sym, location, expected):
        check_symbolic_forward(sym, location, expected, rtol=fw_rtol,
                               atol=fw_atol, dtype=dtype)

    def check_grad(sym, location):
        check_numeric_gradient(sym, location, numeric_eps=eps, rtol=bw_rtol,
                               atol=bw_atol, dtype=dtype)

    def rep_3x(a, m, n):
        # Stack three copies of a into a (3, 1, m, n) batch.
        return np.reshape(np.tile(np.array(a).flatten(), 3), (3, 1, m, n))

    def check_bw(sym, location, out_grads, expected):
        check_symbolic_backward(sym, location, out_grads, expected,
                                rtol=fw_rtol, atol=fw_atol, dtype=dtype)

    # Tests for linalg.syevd
    test_syevd2 = _syevd_combined_symbol(data1)  # Outputs (U U^T, U^T (diag L) U)
    data1_s2 = _make_symm_symbol(data1, ndims=2)
    test_syevd_u_2 = _syevd_first_output(data1_s2)
    test_syevd_l_2 = _syevd_second_output(data1_s2)
    data1_s4 = _make_symm_symbol(data1, ndims=4)
    test_syevd_u_4 = _syevd_first_output(data1_s4)
    test_syevd_l_4 = _syevd_second_output(data1_s4)
    for n in [4, 1, 2, 10, 14]:
        #print('\n** syevd: n={}'.format(n))
        a_np = np.random.normal(0., 10., (n, n))
        a_np = 0.5 * (a_np + a_np.T)  # make exactly symmetric
        res_eye = np.eye(n)
        res_a = a_np
        check_fw(test_syevd2, [a_np], [res_eye, res_a])
        # Check backward against the NumPy reference implementation
        grad_u = np.random.normal(0., 2., (n, n))
        grad_l = np.random.normal(0., 2., (n,))
        bw_u, bw_l = _syevd_forward(a_np)
        grad_a = _syevd_backward(grad_u, grad_l, bw_u, bw_l)
        check_bw(mx.sym.linalg.syevd(data1), [a_np], [grad_u, grad_l], [grad_a])
        if grad_check == 1:
            # A => U
            check_grad(test_syevd_u_2, [a_np])
            # A => L
            check_grad(test_syevd_l_2, [a_np])
        # Batch mode (3x the same thing)
        a_batch = rep_3x(a_np, n, n)
        reye_batch = rep_3x(res_eye, n, n)
        ra_batch = a_batch
        check_fw(test_syevd2, [a_batch], [reye_batch, ra_batch])
        if grad_check == 1:
            # A => U
            check_grad(test_syevd_u_4, [a_batch])
            # A => L
            check_grad(test_syevd_l_4, [a_batch])
+
+
# Note: Currently, linalg.syevd is activated for float64 only, due to the
# issues demonstrated by this unit test. For this reason, the second part of
# this test (float32) is deactivated for now.
def test_laop_4():
    """Checks linalg.syevd on a fixed 2x2 matrix with known eigendecomposition."""
    # syevd is implemented on CPU only for now.
    if not (default_context() == mx.cpu()):
        return
    np.random.seed(1896893923)
    fw_rtol = 1e-6
    fw_atol = 1e-6

    data1 = mx.symbol.Variable('data1')

    def check_fw(sym, location, expected, dtype):
        check_symbolic_forward(sym, location, expected, rtol=fw_rtol,
                               atol=fw_atol, dtype=dtype)

    a_np = np.array([[1., 2.], [2., 4.]])
    u_np = np.array([[0.89442718, -0.44721359], [0.44721359, 0.89442718]])
    l_np = np.array([0., 5.])
    test_syevd = mx.sym.linalg.syevd(data1)
    # float64 passes with the tolerances above
    check_fw(test_syevd, [a_np], [u_np, l_np], np.float64)
    # float32 currently fails and is therefore disabled
    #check_fw(test_syevd, [a_np], [u_np, l_np], np.float32)
+
+
def test_stack():
for _ in range(100):
ndim = random.randint(1, 5)
--
To stop receiving notification emails like this one, please contact
[email protected].