haojin2 commented on a change in pull request #15905: [Numpy] Basic indexing in symbolic interface
URL: https://github.com/apache/incubator-mxnet/pull/15905#discussion_r314582829
 
 

 ##########
 File path: src/operator/numpy/np_matrix_op.cc
 ##########
 @@ -135,17 +136,165 @@ bool NumpyReshapeInferShape(const mxnet::TShape& src, 
mxnet::TShape* dst) {
   }
 }
 
+bool NumpyXReshapeInferShape(const mxnet::TShape& src,
+                             const mxnet::Tuple<int>& target,
+                             mxnet::TShape* output) {
+  bool target_shape_is_known = true;
+  dim_t target_size = 1;
+  for (int i = 0; i < target.ndim(); ++i) {
+    if (target[i] < 0) {
+      target_shape_is_known = false;
+      target_size  = -1;
+      break;
+    } else {
+      target_size *= target[i];
+    }
+  }
+  if (shape_is_known(src) && target_shape_is_known) {
+    CHECK_EQ(src.Size(), target_size) << "Cannot reshape array of size "
+                                      << src.Size() << " into shape " << 
target;
+    *output = TShape(target.begin(), target.end());
+    return true;
+  } else if (!shape_is_known(src) || target.ndim() == -1) {
+    return false;
+  } else {
+    int unknown_axis = -1;
+    dim_t known_dim_size_prod = 1;
+    std::vector<dim_t> output_shape_vector;
+    int src_inx = 0;
+    for (int i = 0; i < target.ndim(); ++i) {
+      dim_t proposed_dim = target[i];
+      CHECK(proposed_dim >= -6)
+        << "Dimension size must be greater than -6, received " << proposed_dim;
+      if (proposed_dim == -1) {
+        // infer the known dimension
+        CHECK_LT(unknown_axis, 0)
+          << "One and only one dim can be inferred";
+        unknown_axis = output_shape_vector.size();
+        output_shape_vector.push_back(1);
+        src_inx++;
+      } else if (proposed_dim == -2) {
+        // copy the dimension from src to output
+        CHECK_LT(src_inx, src.ndim())
+          << "Unmatching dimension of proposed new shape";
+        known_dim_size_prod *= src[src_inx];
+        output_shape_vector.push_back(src[src_inx++]);
+      } else if (proposed_dim == -3) {
+        // skip the source dimension if and only if it is one
+        CHECK_EQ(src[src_inx], 1)
+          <<"-3 index should only be used to skip dimision size 1";
+        src_inx++;
+      } else if (proposed_dim == -4) {
+        // copy all remaining dims from source
+        while (src_inx < src.ndim()) {
+          known_dim_size_prod *= src[src_inx];
+          const int dn = src[src_inx++];
+          output_shape_vector.push_back(dn);
+        }
+      } else if (proposed_dim == -5) {
+        // merge two dims from source
+        CHECK_LT(src_inx, src.ndim()-1)
+          <<"Not enough dimensions left for the product";
+        const int d1 = src[src_inx++];
+        const int d2 = src[src_inx++];
+        if (!mxnet::dim_size_is_known(d1) || !mxnet::dim_size_is_known(d2)) {
+          CHECK_LT(unknown_axis, 0)
+          << "One and only one dim can be inferred";
+          unknown_axis = output_shape_vector.size();
+          output_shape_vector.push_back(-1);
+        } else {
+          known_dim_size_prod *= d1*d2;
+          output_shape_vector.push_back(d1 * d2);
+        }
+      } else if (proposed_dim == -6) {
+        // split the source dim s into two dims
+        // read the left dim and then the right dim (either can be -1)
+        CHECK_LT(i + 2, target.ndim());
+        CHECK_LT(src_inx, src.ndim());
+        const int d0 = src[src_inx++];
+        dim_t d1 = target[++i];
+        dim_t d2 = target[++i];
+        CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
+        if (d1 == -1 && d0 >= 0) d1 = d0 / d2;  // d0 must be known to do this
+        if (d2 == -1 && d0 >= 0) d2 = d0 / d1;  // d0 must be known to do this
+        CHECK(d1 * d2 == static_cast<dim_t>(d0) || static_cast<dim_t>(d0) == 
dim_t(-1))
+          <<"Split dims " << d1 << ", " << d2 << " do not divide original dim 
" << d0;
+        if (d1 == -1) {
+          CHECK_LT(unknown_axis, 0)
+          << "One and only one dim can be inferred";
 
 Review comment:
   Same for all applicable places.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to