haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator
URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198673785
 
 

 ##########
 File path: src/operator/tensor/indexing_op.h
 ##########
 @@ -744,37 +775,36 @@ struct TakeParam: public dmlc::Parameter<TakeParam> {
   }
 };
 
-template<typename PType>
-inline void TakeParamParser(nnvm::NodeAttrs *attrs) {
-    PType param;
-    param.Init(attrs->dict);
-    if (param.axis != 0) {
-        LOG(FATAL) << "Axis other than 0 currently not supported.";
-    }
-    if (param.mode != take_::kClip) {
-        LOG(FATAL) << "Mode other than clip currently not supported.";
-    }
-}
-
 inline bool TakeOpShape(const nnvm::NodeAttrs& attrs,
                         std::vector<TShape> *in_attrs,
                         std::vector<TShape> *out_attrs) {
-    using namespace mshadow;
-    const TShape &arrshape = (*in_attrs)[take_::kArr];
-    const TShape &idxshape = (*in_attrs)[take_::kIdx];
-    if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false;
+  using namespace mshadow;
+  const TShape &arrshape = (*in_attrs)[take_::kArr];
+  const TShape &idxshape = (*in_attrs)[take_::kIdx];
+  if (idxshape.ndim() == 0U || idxshape.Size() == 0U) return false;
+  const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
+  if (param.mode == take_::kRaise) {
+    LOG(FATAL) << "Raise is not supported for the time being...";
+  }
+  CHECK(param.axis >= -1 * (int)arrshape.ndim() && param.axis < 
(int)arrshape.ndim())
+    << "Axis should be in the range of [-r, r-1] where r is the rank of input 
tensor";
 
-    out_attrs->clear();
+  out_attrs->clear();
 
-    TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
-    for (size_t i = 0; i < idxshape.ndim(); ++i) {
-        oshape[i] = idxshape[i];
-    }
-    for (size_t i = 0; i < arrshape.ndim() - 1; i++) {
-        oshape[i + idxshape.ndim()] = arrshape[i + 1];
+  const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 
0);
 
 Review comment:
   Will do

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to