ThomasDelteil commented on a change in pull request #11928: Generalized 
reshape_like operator
URL: https://github.com/apache/incubator-mxnet/pull/11928#discussion_r206230304
 
 

 ##########
 File path: src/operator/tensor/elemwise_unary_op_basic.cc
 ##########
 @@ -350,10 +350,108 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input.");
 
+void ReshapeLikeRangeCanonicalize(int ndims, const char *side, int begin,
+                                  const dmlc::optional<int> &end, int *cbegin,
+                                  int *cend) {
+  *cbegin = begin;
+  if (*cbegin < 0)
+    *cbegin += ndims;
+
+  if (!static_cast<bool>(end)) {
+    *cend = ndims;
+  } else {
+    *cend = end.value();
+    if (*cend < 0) {
+      *cend += ndims;
+    }
+  }
+  CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end
+                        << " as dimension number is " << ndims;
+  CHECK((*cbegin < *cend)) << "Invalid begin, end, get " << side
+                           << "_begin=" << begin << ", " << side
+                           << "_end=" << end;
+
+  CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end;
+  CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin;
+}
+
+void GetReshapeLikeParams(const ReshapeLikeParam &param, const TShape &lshape,
+                          const TShape &rshape, int *lhs_begin, int *lhs_end,
+                          int *rhs_begin, int *rhs_end) {
+  // LHS params
+  ReshapeLikeRangeCanonicalize(lshape.ndim(), "lhs", param.lhs_begin,
+                               param.lhs_end, lhs_begin, lhs_end);
+  // RHS params
+  ReshapeLikeRangeCanonicalize(rshape.ndim(), "rhs", param.rhs_begin,
+                               param.rhs_end, rhs_begin, rhs_end);
+}
 
+bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs,
+                             std::vector<TShape> *in_attrs,
+                             std::vector<TShape> *out_attrs) {
+  const ReshapeLikeParam &param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
+  const TShape &lshape = (*in_attrs)[0];
+  const TShape &rshape = (*in_attrs)[1];
+  int lhs_begin, lhs_end, rhs_begin, rhs_end;
+  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin,
+                       &rhs_end);
+
+  int lhsrank = static_cast<int>(lshape.ndim());
+  int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
+  TShape oshape(orank);
+
+  for (int i = 0; i < lhs_begin; ++i)
+    oshape[i] = lshape[i];
+
+  int opos = lhs_begin;
+  for (int i = rhs_begin; i < rhs_end; ++i) {
+    oshape[opos] = rshape[i];
+    opos += 1;
+  }
+
+  for (int i = lhs_end; i < lhsrank; ++i) {
+    oshape[opos] = lshape[i];
+    opos += 1;
+  }
+
+  CHECK_EQ((*in_attrs)[0].Size(), oshape.Size())
+      << "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to new "
+      << "shape " << oshape << " because they have different "
+      << "size.";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(ReshapeLikeParam);
 NNVM_REGISTER_OP(reshape_like)
-.describe("Reshape lhs to have the same shape as rhs.")
+.describe(R"code(Reshape `lhs` to have the same shape as `rhs`.
 
 Review comment:
   I would alter this line to say something more along the lines of:
   "Reshape some or all dimensions of `lhs` to have the same shape as some or 
all dimensions of `rhs`"

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to