haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator
URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198673354
 
 

 ##########
 File path: src/operator/tensor/indexing_op.h
 ##########
 @@ -321,10 +309,53 @@ struct Take {
   MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
                                   const IType* idx, const int M, const int K) {
     int j = static_cast<int>(idx[i/M]);
-    if (j <= 0) j = 0;
-    else if (j >= K) j = K - 1;
+    if (clip) {
+      if (j <= 0) j = 0;
+      else if (j >= K) j = K - 1;
+    } else {
+      j = j % K;
+      j += (j < 0) ? K : 0;
+    }
     out_data[i] = in_data[j * M + i % M];
   }
+
+  /*!
+   * \brief Map function for take operator
+   * \param i           global thread id
+   * \param out_data    ptr to output buffer
+   * \param in_data     ptr to input buffer
+   * \param idx         ptr to indices buffer
+   * \param in_ndims    # of dims of input tensor
+   * \param out_ndims   # of dims of output tensor
+   * \param idx_ndims   # of dims of indices tensor
+   * \param axis_dim    dim size of the axis dimension
+   * \param axis        axis id
+   */
+  template<typename DType, typename IType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* 
in_data, const IType* idx,
+                                  const mshadow::Shape<10> in_stride,
+                                  const mshadow::Shape<10> out_stride,
+                                  const int in_ndims, const int out_ndims, 
const int idx_ndims,
+                                  const int axis_dim, const int axis) {
+    // i is the global flattened index in the output
+    const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
 
 Review comment:
   IType may be a floating-point type, so the compiler will complain about it
here. That is also why the legacy Map function above casts the index to an
integer (`static_cast<int>(idx[i/M])`).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to