This is an automated email from the ASF dual-hosted git repository. apeforest pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 5950d8c Improving performance of broadcast_axis on GPU (#18168) 5950d8c is described below commit 5950d8c403d5966370a6f5a144a9e2137925c6ff Author: Rohit Kumar Srivastava <srivastava....@osu.edu> AuthorDate: Fri May 1 14:18:52 2020 -0700 Improving performance of broadcast_axis on GPU (#18168) * adding separate int32_t kernel for GPU in broadcast_axis/to/like operators * using structure instead of temp workspace to pass stride and shape * replacing hardcoded int32_t with generic index_t * combining CPU and GPU kernels to leverage cached stride calculation and fast access shape data in both Co-authored-by: Rohit Kumar Srivastava <srivastava....@buckeyemail.osu.edu> --- src/operator/numpy/np_matmul_op-inl.h | 7 +++- src/operator/tensor/broadcast_reduce_op.h | 67 ++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/operator/numpy/np_matmul_op-inl.h b/src/operator/numpy/np_matmul_op-inl.h index 89560f6..c1f0eed 100644 --- a/src/operator/numpy/np_matmul_op-inl.h +++ b/src/operator/numpy/np_matmul_op-inl.h @@ -157,12 +157,15 @@ inline void MatmulImpl(const OpContext& ctx, DType* bc_b_ptr = bc_a_ptr + bc_size_a; MSHADOW_TYPE_SWITCH_WITH_BOOL(input_a.type_flag_, IType, { MSHADOW_TYPE_SWITCH_WITH_BOOL(input_b.type_flag_, OType, { + struct ShapeAndStride aux_data_a, aux_data_b; + PrepareAUXData(&aux_data_a, k_a_shape, k_a_shape_bc, ndim); + PrepareAUXData(&aux_data_b, k_b_shape, k_b_shape_bc, ndim); Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch( s, bc_size_a, input_a.dptr<IType>(), bc_a_ptr, - k_a_shape, k_a_shape_bc, OpReqType::kWriteTo, ndim); + aux_data_a, OpReqType::kWriteTo, ndim); Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch( s, bc_size_b, input_b.dptr<IType>(), bc_b_ptr, - k_b_shape, k_b_shape_bc, OpReqType::kWriteTo, ndim); + aux_data_b, OpReqType::kWriteTo, ndim); }); }); ans = mshadow::Tensor<xpu, 3, DType>(output.dptr<DType>(), diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 03aa8b9..12af331 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -1049,29 +1049,66 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs, ReduceAxesBackwardUseInOutImpl<xpu, OP, normalize>(ctx, small, inputs, req, outputs); } +namespace { // unnamed namespace to keep scope of the struct within the file +struct ShapeAndStride { + index_t in_stride[MXNET_SPECIAL_MAX_NDIM]; + index_t out_stride[MXNET_SPECIAL_MAX_NDIM]; + index_t input_shape[MXNET_SPECIAL_MAX_NDIM]; + index_t output_shape[MXNET_SPECIAL_MAX_NDIM]; +}; + +/*! + * \brief Calculates Stride of input and output tensor dimesnions + And saves mshadow::Shape data in an integer array for + faster access. + * \param *aux_data to hold stride and shape data. + * \param in_shape input shape + * \param out_shape output shape + * \param ndim no of dimensions in output + */ +inline void PrepareAUXData(ShapeAndStride *aux_data, + mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape, + mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape, + int ndim) { + int iter = ndim - 1; + aux_data->out_stride[iter] = 1; + aux_data->in_stride[iter] = 1; + aux_data->input_shape[iter] = in_shape[iter]; + aux_data->output_shape[iter] = out_shape[iter]; + iter--; + for (; iter >= 0; --iter) { + aux_data->out_stride[iter] = aux_data->out_stride[iter + 1] * out_shape[iter + 1]; + aux_data->in_stride[iter] = aux_data->in_stride[iter + 1] * in_shape[iter + 1]; + aux_data->input_shape[iter] = in_shape[iter]; + aux_data->output_shape[iter] = out_shape[iter]; + } +} +} // unnamed namespace + template<typename OP> struct broadcast_kernel { template<typename IType, typename OType> MSHADOW_XINLINE static void Map(index_t i, IType *input, OType *output, - mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape, - mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape, + const ShapeAndStride& aux_data, const OpReqType req, - const uint32_t ndim) { - size_t in_stride = 1; - size_t out_stride = 1; + const int ndim) { index_t idx = i; index_t in_idx = i; +#pragma unroll 4 for (int iter = ndim - 1; iter >= 0; --iter) { - size_t dim_idx = idx % out_shape[iter]; - in_idx -= dim_idx * out_stride; - if (in_shape[iter] != 1) { - in_idx += dim_idx * in_stride; + index_t out_dim_shape = aux_data.output_shape[iter]; + index_t out_dim_stride = aux_data.out_stride[iter]; + // x % y = x - (x / y) * y + // speeds up modulo(%) operation in GPU + index_t dim_idx = idx - (idx / out_dim_shape) * out_dim_shape; + if (aux_data.input_shape[iter] != 1) { + in_idx += dim_idx * (aux_data.in_stride[iter] - out_dim_stride); + } else { + in_idx -= dim_idx * out_dim_stride; } - idx /= out_shape[iter]; - in_stride *= in_shape[iter]; - out_stride *= out_shape[iter]; + idx /= out_dim_shape; } KERNEL_ASSIGN(output[i], req, OP::Map(input[in_idx])); } @@ -1103,13 +1140,15 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, out_shape[i] = 1; } } + struct ShapeAndStride aux_data; + PrepareAUXData(&aux_data, in_shape, out_shape, dst_shape.ndim()); if (dst_shape.ndim() == 2) { Tensor<xpu, 2, OType> out = outputs[0].get_with_shape<xpu, 2, OType>(dst_shape.get<2>(), s); Tensor<xpu, 2, IType> data = inputs[0].get_with_shape<xpu, 2, IType>(src_shape.get<2>(), s); Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], 2); + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], 2); } else { const int ndim = MXNET_SPECIAL_MAX_NDIM; Tensor<xpu, ndim, OType> out = @@ -1117,7 +1156,7 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, Tensor<xpu, ndim, IType> data = inputs[0].get_with_shape<xpu, ndim, IType>(src_shape.get<ndim>(), s); Kernel<broadcast_kernel<mshadow_op::identity>, xpu>::Launch( - s, out.shape_.Size(), data.dptr_, out.dptr_, in_shape, out_shape, req[0], ndim); + s, out.shape_.Size(), data.dptr_, out.dptr_, aux_data, req[0], ndim); } }); });