eric-haibin-lin commented on a change in pull request #8259: check_format of 
ndarray, mainly for csr
URL: https://github.com/apache/incubator-mxnet/pull/8259#discussion_r145605904
 
 

 ##########
 File path: src/common/utils.h
 ##########
 @@ -43,9 +43,88 @@
 #include <algorithm>
 #include <functional>
 
+#include "../operator/mxnet_op.h"
+#include "../ndarray/ndarray_function.h"
+
 namespace mxnet {
 namespace common {
 
+
+/*! 
+ * \brief IndPtr should be in non-decreasing order, start with 0
+ *           and end with value greater or equal than size of indices.
+ */
+struct indptr_check {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, mshadow::default_real_t* out, const 
DType* in,
+                                  const nnvm::dim_t end, const nnvm::dim_t 
idx_size) {
+    if ((in[i+1] < in[i]) || (i == 0 && in[i] != static_cast<DType>(0)) ||
+        (i == end && in[i] < static_cast<DType>(idx_size)))
+          *out = kCSRIndPtrErr;
+  }
+};
+
+/*!
+ *  \brief Indices should be less than the number of columns.
+ */
+struct idx_check {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, mshadow::default_real_t* out,
+                                  const DType* in, const nnvm::dim_t ncols) {
+    if (in[i] >= static_cast<DType>(ncols)) *out = kCSRIdxErr;
+  }
+};
+
+template<typename xpu>
+void CheckFormatWrapper(const RunContext &rctx, const NDArray *input,
+                        TBlob *cpu_err, const bool &full_check);
+
+template<typename xpu>
+void CheckFormatImpl(const RunContext &rctx, const NDArray *input,
+                     TBlob *cpu_err, const bool &full_check) {
+  using namespace op::mxnet_op;
+  auto stype = input->storage_type();
+  auto err = cpu_err->dptr<mshadow::default_real_t>();
+  if (stype == kCSRStorage) {
+    const TShape shape = input->shape();
+    const TShape idx_shape = input->aux_shape(csr::kIdx);
+    const TShape indptr_shape = input->aux_shape(csr::kIndPtr);
+    const TShape storage_shape = input->storage_shape();
+    if ((shape.ndim() != 2) ||
+        (idx_shape.ndim() != 1 || indptr_shape.ndim() != 1 || 
storage_shape.ndim() != 1) ||
+        (indptr_shape[0] != shape[0] + 1) ||
+        (idx_shape[0] != storage_shape[0])) {
+          *err = kCSRShapeErr;
+          return;
+    }
+    if (full_check) {
+      NDArray xpu_ret = NDArray(mshadow::Shape1(1), rctx.get_ctx());
+      TBlob xpu_tmp = xpu_ret.data();
+      ndarray::Eval<xpu>(kNormalErr, &xpu_tmp, rctx);
+      int indptr_type = input->aux_type(csr::kIndPtr);
+      MSHADOW_TYPE_SWITCH(indptr_type, IType, {
+        Kernel<indptr_check, xpu>::Launch(
 
 Review comment:
  Since callers of this function may pass a TBlob whose dtype is kInt32,
kFloat32, etc., we should add a TYPE_SWITCH on the dtype of the TBlob passed in
instead of assuming it is mshadow::default_real_t.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to