haojin2 commented on a change in pull request #11326: [MXNET-381] Enhancement of take operator URL: https://github.com/apache/incubator-mxnet/pull/11326#discussion_r198673354
########## File path: src/operator/tensor/indexing_op.h ########## @@ -321,10 +309,53 @@ struct Take { MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, const int M, const int K) { int j = static_cast<int>(idx[i/M]); - if (j <= 0) j = 0; - else if (j >= K) j = K - 1; + if (clip) { + if (j <= 0) j = 0; + else if (j >= K) j = K - 1; + } else { + j = j % K; + j += (j < 0) ? K : 0; + } out_data[i] = in_data[j * M + i % M]; } + + /*! + * \brief Map function for take operator + * \param i global thread id + * \param out_data ptr to output buffer + * \param in_data ptr to input buffer + * \param idx ptr to indices buffer + * \param in_ndims # of dims of input tensor + * \param out_ndims # of dims of output tensor + * \param idx_ndims # of dims of indices tensor + * \param axis_dim dim size of the axis dimension + * \param axis axis id + */ + template<typename DType, typename IType> + MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data, const IType* idx, + const mshadow::Shape<10> in_stride, + const mshadow::Shape<10> out_stride, + const int in_ndims, const int out_ndims, const int idx_ndims, + const int axis_dim, const int axis) { + // i is the global flattened index in the output + const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]); Review comment: IType may be a floating-point type, in which case the compiler will complain about it. That is also why the legacy Map function above uses a cast. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services