This is an automated email from the ASF dual-hosted git repository. bgawrych pushed a commit to branch take_opt in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit ad38228e1b334bf657f4133b903febc42f0d74d2 Author: Bartlomiej Gawrych <[email protected]> AuthorDate: Mon Nov 15 09:07:35 2021 +0100 Improve performance of take operator --- src/operator/tensor/indexing_op.cc | 94 ++++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 30 deletions(-) diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 3082541..4602640 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -60,6 +60,46 @@ struct TakeZeroAxisCPU { } }; +template <bool clip = true> +struct TakeNonzeroAxisCPU { + /*! + * \brief Map function for take operator on CPU along a non-zero axis: for outer position i, copies one contiguous run of axis_dim_stride elements per index (indices are clamped to [0, axis_dim - 1] when clip is true, otherwise wrapped modulo axis_dim) + * \param i work-item id; used as the flattened position over the dimensions before axis + * \param out_data ptr to output buffer + * \param in_data ptr to input buffer + * \param indices ptr to indices buffer, read as a flat array of idx_size entries (each entry is converted to int) + * \param outer_dim_stride stride of dimension before axis in the input buffer + * \param axis_dim_stride stride of axis dimension; also the number of elements copied per index + * \param idx_size size of the indices tensor (total number of indices) + * \param axis_dim dim size of the axis dimension + * \param axis axis id (not read by this function) + */ + template <typename DType, typename IType> + MSHADOW_XINLINE static void Map(index_t i, + DType* out_data, + const DType* in_data, + const IType* indices, + const index_t outer_dim_stride, + const index_t axis_dim_stride, + const int idx_size, + const int axis_dim, + const int axis) { + for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) { // 4 + int index = indices[j]; + if (clip) { + index = (index < 0) ? 0 : index; + index = (index > axis_dim - 1) ? (axis_dim - 1) : index; + } else { + index %= axis_dim; + index += (index < 0) ? 
axis_dim : 0; + } + size_t in_offset = i * outer_dim_stride + index * axis_dim_stride; + size_t out_offset = (i * idx_size + j) * axis_dim_stride; + memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType)); + } + } +}; + /* * \brief returns true if all indices are between [min, max] * \param data_ptr the indices to check @@ -323,6 +363,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs) { using namespace mxnet_op; + if (req[take_::kOut] == kNullOp) return; const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed); @@ -375,39 +416,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs, for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) { in_strides[i] = stride; } - mshadow::Shape<10> out_strides; - stride = 1; - for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) { - out_strides[i] = stride; + int outer_dimensions = 1; + for (int i = 0; i < actual_axis; i++) { + outer_dimensions *= oshape[i]; } if (param.mode == take_::kClip) { - Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s, - oshape.Size(), - outputs[take_::kOut].dptr<DType>(), - inputs[take_::kArr].dptr<DType>(), - inputs[take_::kIdx].dptr<IType>(), - out_strides[actual_axis - 1], - in_strides[actual_axis - 1], - in_strides[actual_axis], - arrshape.ndim(), - oshape.ndim(), - idxshape.ndim(), - arrshape[actual_axis], - actual_axis); + Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s, + outer_dimensions, + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), + in_strides[actual_axis - 1], + in_strides[actual_axis], + idxshape.Size(), + arrshape[actual_axis], + actual_axis); } else { - Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s, - oshape.Size(), - outputs[take_::kOut].dptr<DType>(), - inputs[take_::kArr].dptr<DType>(), - inputs[take_::kIdx].dptr<IType>(), - out_strides[actual_axis - 1], - in_strides[actual_axis - 1], - 
in_strides[actual_axis], - arrshape.ndim(), - oshape.ndim(), - idxshape.ndim(), - arrshape[actual_axis], - actual_axis); + Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s, + outer_dimensions, + outputs[take_::kOut].dptr<DType>(), + inputs[take_::kArr].dptr<DType>(), + inputs[take_::kIdx].dptr<IType>(), + in_strides[actual_axis - 1], + in_strides[actual_axis], + idxshape.Size(), + arrshape[actual_axis], + actual_axis); } } });
