ZiyueHuang commented on a change in pull request #8259: check_format of sparse 
ndrray
URL: https://github.com/apache/incubator-mxnet/pull/8259#discussion_r150372061
 
 

 ##########
 File path: src/common/utils.h
 ##########
 @@ -43,9 +43,177 @@
 #include <algorithm>
 #include <functional>
 
+#include "../operator/mxnet_op.h"
+
 namespace mxnet {
 namespace common {
 
+
+/*!
+ * \brief IndPtr should be non-negative, in non-decreasing order, start with 0
+ *           and end with a value equal to the size of indices.
+ *
+ *        Each call inspects the adjacent pair (indptr[i], indptr[i+1]) and
+ *        writes kCSRIndPtrErr to *out on a violation; *out is left untouched
+ *        when the pair is valid.
+ * \param i index of the first entry of the pair examined by this call.
+ * \param out single error slot shared by all calls.
+ * \param indptr the index-pointer array; entries i and i+1 (and end) are read.
+ * \param end index of the last indptr entry that is compared with idx_size
+ *            (presumably the number of rows - confirm at the call site).
+ * \param idx_size value that indptr[end] must equal.
+ */
+struct csr_indptr_check {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* indptr,
+                                  const nnvm::dim_t end, const nnvm::dim_t 
idx_size) {
+    if (indptr[i+1] < 0 || indptr[i+1] < indptr[i] ||  // negative or decreasing
+        (i == 0 && indptr[i] != 0) ||  // first entry must be 0
+        (i == end - 1 && indptr[end] != idx_size))  // last entry must be idx_size
+      *out = kCSRIndPtrErr;
+  }
+};
+
+/*!
+ *  \brief Indices should be non-negative, less than the number of columns
+ *           and in ascending order per row.
+ *
+ *         Each call scans the entries idx[indptr[i]] .. idx[indptr[i+1] - 1]
+ *         and writes kCSRIdxErr to *out at the first violation found; *out is
+ *         left untouched when every entry in that range is valid.
+ *  \param i position in indptr whose range [indptr[i], indptr[i+1]) is scanned.
+ *  \param out single error slot shared by all calls.
+ *  \param idx the indices array being validated.
+ *  \param indptr the index-pointer array delimiting each range in idx.
+ *  \param ncols exclusive upper bound for every value in idx.
+ */
+struct csr_idx_check {
+  template<typename DType, typename IType, typename RType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx,
+                                  const RType* indptr, const nnvm::dim_t 
ncols) {
+    for (RType j = indptr[i]; j < indptr[i+1]; j++) {
+      if (idx[j] >= ncols || idx[j] < 0 ||  // out of [0, ncols)
+          (j < indptr[i+1] - 1 && idx[j] >= idx[j+1])) {  // equal neighbours are rejected too
+        *out = kCSRIdxErr;
+        break;  // one violation is enough; stop scanning this range
+      }
+    }
+  }
+};
+
+/*!
+ *  \brief Indices of RSPNDArray should be non-negative,
+ *           less than the size of first dimension and in ascending order
+ *
+ *         Each call inspects idx[i] (and idx[i+1] when i < end) and writes
+ *         kRSPIdxErr to *out on a violation; *out is left untouched otherwise.
+ *  \param i position in idx examined by this call.
+ *  \param out single error slot shared by all calls.
+ *  \param idx the row-index array being validated.
+ *  \param end bound below which idx[i+1] is also read; presumably the index of
+ *             the last element of idx - confirm at the call site.
+ *  \param nrows exclusive upper bound for every value in idx.
+ */
+struct rsp_idx_check {
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const IType* idx,
+                                  const nnvm::dim_t end, const nnvm::dim_t 
nrows) {
+    if ((i < end && idx[i+1] <= idx[i])  // equal neighbours are rejected too
+        || idx[i] < 0 || idx[i] >= nrows)  // out of [0, nrows)
+      *out = kRSPIdxErr;
+  }
+};
+
+template<typename xpu>
+void CheckFormatWrapper(const RunContext &rctx, const NDArray &input,
+                        const TBlob &err_cpu, const bool full_check);
+
+/*!
+ * \brief Check the validity of CSRNDArray.
+ * \param rctx Execution context.
+ * \param input Input NDArray of CSRStorage.
+ * \param err_cpu Error number on cpu.
+ * \param full_check If true, rigorous check, O(N) operations,
+ *          otherwise basic check, O(1) operations.
+ */
+template<typename xpu>
+void CheckFormatCSRImpl(const RunContext &rctx, const NDArray &input,
+                        const TBlob &err_cpu, const bool full_check) {
+  using namespace op::mxnet_op;
+  CHECK_EQ(input.storage_type(), kCSRStorage)
+          << "CheckFormatCSRImpl is for CSRNDArray";
+  const TShape shape = input.shape();
+  const TShape idx_shape = input.aux_shape(csr::kIdx);
+  const TShape indptr_shape = input.aux_shape(csr::kIndPtr);
+  const TShape storage_shape = input.storage_shape();
+  if ((shape.ndim() != 2) ||
+      (idx_shape.ndim() != 1 || indptr_shape.ndim() != 1 || 
storage_shape.ndim() != 1) ||
+      (indptr_shape[0] != shape[0] + 1) ||
+      (idx_shape[0] != storage_shape[0])) {
+     MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
+       DType* err = err_cpu.dptr<DType>();
+       *err = kCSRShapeErr;
+     });
+     return;
+  }
+  if (full_check) {
+    MSHADOW_TYPE_SWITCH(err_cpu.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIndPtr), RType, {
+        MSHADOW_IDX_TYPE_SWITCH(input.aux_type(csr::kIdx), IType, {
+          mshadow::Stream<xpu> *s = rctx.get_stream<xpu>();
+          NDArray ret_xpu = NDArray(mshadow::Shape1(1),
+                                    rctx.get_ctx(), false, err_cpu.type_flag_);
+          TBlob val_xpu = ret_xpu.data();
+          Kernel<set_to_int<kNormalErr>, xpu>::Launch(s, val_xpu.Size(), 
val_xpu.dptr<DType>());
 
 Review comment:
   This array, ret_xpu, has shape (1,) and just holds the error number, so it 
is initially set to kNormalErr. It is created on the same context as the 
source array and is used for the kernel launch.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to