eric-haibin-lin commented on a change in pull request #8259: check_format of ndarray, mainly for csr URL: https://github.com/apache/incubator-mxnet/pull/8259#discussion_r145605552
########## File path: src/common/utils.h ########## @@ -43,9 +43,88 @@ #include <algorithm> #include <functional> +#include "../operator/mxnet_op.h" +#include "../ndarray/ndarray_function.h" + namespace mxnet { namespace common { + +/*! + * \brief IndPtr should be in non-decreasing order, start with 0 + * and end with value greater or equal than size of indices. + */ +struct indptr_check { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, mshadow::default_real_t* out, const DType* in, + const nnvm::dim_t end, const nnvm::dim_t idx_size) { + if ((in[i+1] < in[i]) || (i == 0 && in[i] != static_cast<DType>(0)) || + (i == end && in[i] < static_cast<DType>(idx_size))) + *out = kCSRIndPtrErr; + } +}; + +/*! + * \brief Indices should be less than the number of columns. + */ +struct idx_check { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, mshadow::default_real_t* out, + const DType* in, const nnvm::dim_t ncols) { + if (in[i] >= static_cast<DType>(ncols)) *out = kCSRIdxErr; + } +}; + +template<typename xpu> Review comment: Let's add some documentation for this function, since this will be the one others use when they try to check the fmt of an NDArray in other backend cpp code. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services