This is an automated email from the ASF dual-hosted git repository.
sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 16eb89b Add GPU-optimization for split op (#19131)
16eb89b is described below
commit 16eb89b71d9bdbbb63ce07b7ac80fc029e6b7061
Author: MoisesHer <[email protected]>
AuthorDate: Sun Oct 11 14:33:31 2020 -0700
Add GPU-optimization for split op (#19131)
* Add GPU-optimization for split op
* Complete operator
* unit-test: use parametrize
* fix lint
* fix lint
* fix lint
---
src/operator/tensor/matrix_op.cu | 218 +++++++++++++++++++++++++++++++++-
tests/python/gpu/test_operator_gpu.py | 18 +++
2 files changed, 235 insertions(+), 1 deletion(-)
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 1707693..d5e17da 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -137,6 +137,222 @@ void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape &begin,
const mxnet::TShape &en
});
}
/*!
 * \brief Kernel argument describing one chunk of a split (at most MaxSections outputs).
 *
 * Passed by value to split_tensor_kernel. Section i of the chunk covers the half-open
 * range [indices[i], indices[i + 1]) along the split axis and is written to outputs[i];
 * inputs[0] points at the tensor being split.
 */
template <typename DType>
struct split_tensor_data {
  // Maximum number of outputs handled by a single kernel launch; keeps the argument a
  // fixed, bounded size.
  static constexpr int MaxSections = 128;
  // Number of valid entries in `outputs`; `indices` then holds num_sections + 1 entries.
  size_t num_sections;
  // Destination pointer of each section in this chunk.
  DType* outputs[MaxSections];
  // Section boundaries along the split axis, in elements.
  size_t indices[MaxSections + 1];
  // Source tensor.
  DType* inputs[1];
};
+
/*!
 * \brief Returns the section of `params` that owns `position` along the split axis.
 *
 * Binary search for the largest s in [0, params.num_sections - 1] with
 * params.indices[s] <= position. Requires params.num_sections >= 1 and
 * params.indices[0] <= position.
 */
template <typename DType>
__device__ __forceinline__ size_t split_find_section(const split_tensor_data<DType>& params,
                                                     size_t position) {
  size_t lower = 0;
  size_t upper = params.num_sections - 1;
  while (lower < upper) {
    const size_t mid = (lower + upper + 1) / 2;
    if (position >= params.indices[mid]) {
      lower = mid;
    } else {
      upper = mid - 1;
    }
  }
  return upper;
}

/*!
 * \brief Copies one chunk of split outputs out of the input with LType-wide loads/stores.
 *
 * Each thread moves one LType word (sizeof(LType) / sizeof(DType) elements), and each row
 * of the last axis is covered by `blocks_last_axis` consecutive blocks.
 *  - split_last_axis == true: rows are the flattened leading dims; each thread finds the
 *    section that owns its position inside [indices[0], indices[num_sections]).
 *  - split_last_axis == false: the input is viewed as [leading, split axis, tail, last];
 *    each block copies part of one last-axis row whose split-axis position lies in
 *    [indices[0], indices[num_sections]).
 *
 * Preconditions established by the host launcher: sizeof(DType) <= sizeof(LType), the last
 * axis length is a multiple of the elements per load and, when splitting the last axis,
 * so is every section boundary of the chunk. The input/output pointers are assumed to be
 * sizeof(LType)-aligned (not checked here).
 *
 * \param input_size       total number of input elements (currently unused).
 * \param params           chunk description (outputs, section boundaries, input).
 * \param split_axis_size  full length of the split axis in the input.
 * \param tail_size        product of the dims strictly between the split axis and the last.
 * \param last_axis_size   length of the last axis in DType elements.
 * \param blocks_last_axis number of blocks covering one last-axis row.
 */
template <bool split_last_axis, typename LType, typename DType>
__global__ void split_tensor_kernel(size_t input_size,
                                    const split_tensor_data<DType> params,
                                    size_t split_axis_size,
                                    size_t tail_size,
                                    size_t last_axis_size,
                                    size_t blocks_last_axis) {
  const int entries_per_load = sizeof(LType) / sizeof(DType);
  const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
  const size_t last_axis_size_aligned = entries_per_load > 0 ?
                                        last_axis_size / entries_per_load :
                                        last_axis_size;
  // Which last-axis row this block works on, and which slice of that row.
  const size_t row = blockIdx.x / blocks_last_axis;
  const size_t block_in_row = blockIdx.x % blocks_last_axis;
  if (split_last_axis) {
    const size_t input_offset_leading = row * last_axis_size_aligned;
    // Position (in DType elements) along the last axis of this thread's first element.
    const size_t position_last_axis = block_in_row * blockDim.x * entries_per_load +
                                      params.indices[0] + threadIdx.x * entries_per_load;
    if (position_last_axis < params.indices[params.num_sections]) {
      const size_t position_last_axis_aligned = entries_per_load > 0 ?
                                                position_last_axis / entries_per_load :
                                                position_last_axis;
      const LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
      const size_t section = split_find_section(params, position_last_axis);
      const size_t section_size = params.indices[section + 1] - params.indices[section];
      LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
      const size_t section_size_aligned = entries_per_load > 0 ?
                                          section_size / entries_per_load :
                                          section_size;
      const size_t index_aligned = entries_per_load > 0 ?
                                   params.indices[section] / entries_per_load :
                                   params.indices[section];
      // Output rows have the section's length; position is relative to the section start.
      const size_t output_offset_leading = row * section_size_aligned;
      out_aligned[output_offset_leading + position_last_axis_aligned - index_aligned] =
          input_data;
    }
  } else {
    // Number of split-axis positions handled by this launch.
    const size_t split_axis_size_iter =
        params.indices[params.num_sections] - params.indices[0];
    const size_t blocks_per_leading_dim = split_axis_size_iter * tail_size * blocks_last_axis;
    const size_t leading = blockIdx.x / blocks_per_leading_dim;
    const size_t pos_in_split_axis = (blockIdx.x / (tail_size * blocks_last_axis)) %
                                     split_axis_size_iter + params.indices[0];
    // Input offsets: leading dims, split-axis position and tail position.
    const size_t input_offset =
        leading * split_axis_size * tail_size * last_axis_size_aligned +
        pos_in_split_axis * tail_size * last_axis_size_aligned +
        (row % tail_size) * last_axis_size_aligned;
    const size_t section = split_find_section(params, pos_in_split_axis);
    const size_t section_size = params.indices[section + 1] - params.indices[section];
    LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
    // Output offsets: leading dims (sized by the section) plus the row index relative to
    // the section start, which already folds in the tail position.
    const size_t output_offset =
        leading * section_size * tail_size * last_axis_size_aligned +
        ((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
         (params.indices[section] - params.indices[0]) * tail_size) *
        last_axis_size_aligned;
    // The last block of a row usually extends past the row end; those positions belong
    // to another row (or lie past the end of the buffers) and must not be touched.
    const size_t position_in_row = block_in_row * blockDim.x + threadIdx.x;
    if (position_in_row < last_axis_size_aligned) {
      out_aligned[output_offset + position_in_row] = in_aligned[input_offset + position_in_row];
    }
  }
}
+
/*!
 * \brief Chooses the widest load type (8, 4, 2 or 1 bytes) usable by split_tensor_kernel.
 *
 * The load width must divide the byte length of the last axis. When splitting along the
 * last axis it must also divide:
 *  - the byte offset at which this chunk starts (indices[0]), which is non-zero for every
 *    chunk after the first;
 *  - the byte length of every section in the chunk.
 * Together these guarantee that every section boundary falls on a load boundary.
 *
 * \tparam DType element type of the tensor.
 * \param last_axis_size length of the last axis, in elements.
 * \param splitting_last_axis whether the split axis is the last axis.
 * \param n_sections number of sections in the chunk; `indices` holds n_sections + 1 entries.
 * \param indices section boundaries along the split axis, in elements.
 * \return kFloat64, kFloat32, kFloat16 or kUint8 for 8-, 4-, 2- or 1-byte loads.
 */
template <typename DType>
int get_load_type_split(size_t last_axis_size,
                        bool splitting_last_axis,
                        size_t n_sections,
                        size_t* indices) {
  using namespace mshadow;
  // OR together every byte count the load width has to divide. A power of two 2^k divides
  // all of them exactly when the low k bits of the combined value are zero.
  size_t combined_bytes = last_axis_size * sizeof(DType);
  if (splitting_last_axis) {
    combined_bytes |= indices[0] * sizeof(DType);
    for (size_t i = 0; i < n_sections; ++i) {
      combined_bytes |= (indices[i + 1] - indices[i]) * sizeof(DType);
    }
  }
  if (combined_bytes % 8 == 0) {
    return kFloat64;
  } else if (combined_bytes % 4 == 0) {
    return kFloat32;
  } else if (combined_bytes % 2 == 0) {
    return kFloat16;
  }
  return kUint8;
}
+
/*!
 * \brief GPU forward pass of split_v2.
 *
 * Uses split_tensor_kernel when two conditions hold:
 *  - the last axis has at least 128 elements;
 *  - every output is requested with kWriteTo.
 * The kernel is launched once per chunk of at most split_tensor_data<DType>::MaxSections
 * outputs, with the widest load type get_load_type_split allows. In all other cases the
 * generic SplitOpForwardImpl<gpu> is used.
 */
inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
  const TBlob& input_data = inputs[split_enum::kData];
  int real_axis = param.axis;
  if (real_axis < 0) {
    real_axis += input_data.ndim();
  }
  // Validate the axis before it is used to index the input shape.
  CHECK_GE(real_axis, 0);
  CHECK_LT(real_axis, input_data.ndim());

  // split_tensor_kernel always overwrites its outputs, so any other request type
  // (kAddTo, kNullOp, ...) must go through the generic implementation, which honours req.
  bool all_write_to = true;
  for (const OpReqType r : req) {
    if (r != kWriteTo) {
      all_write_to = false;
      break;
    }
  }
  const size_t last_axis_size = input_data.shape_[input_data.ndim() - 1];
  if (last_axis_size < 128 || !all_write_to) {
    // The custom kernel is not efficient with fewer than 128 elements in the last axis.
    SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
    return;
  }

  Stream<gpu>* s = ctx.get_stream<gpu>();
  const mxnet::TShape& ishape = input_data.shape_;
  const size_t split_axis_size = ishape[real_axis];
  size_t tail_size = 1;  // product of the dims strictly between the split axis and the last
  for (int i = real_axis + 1; i < input_data.ndim() - 1; ++i) {
    tail_size *= ishape[i];
  }
  // Section boundaries along the split axis, ending with the length of the axis.
  const mxnet::TShape split_pts =
      (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
  std::vector<size_t> indices;
  for (const auto& split_pos : split_pts) {
    indices.push_back(split_pos);
  }
  if (param.sections == 0) {
    indices.push_back(ishape[real_axis]);
  }
  const size_t n_sections = indices.size() - 1;
  const bool splitting_last_axis = (real_axis == input_data.ndim() - 1);

  for (size_t sections_processed = 0; sections_processed < n_sections;) {
    const size_t remaining_sections = n_sections - sections_processed;
    MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
      // Describe the next chunk of at most MaxSections outputs.
      split_tensor_data<DType> params{};
      params.num_sections = std::min<size_t>(remaining_sections, params.MaxSections);
      params.inputs[0] = input_data.dptr<DType>();
      for (size_t i = 0; i < params.num_sections; ++i) {
        params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
        params.indices[i] = indices[sections_processed + i];
      }
      params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
      // The load type must divide the last axis and, when splitting the last axis,
      // every section boundary of this chunk.
      const int ltype = get_load_type_split<DType>(last_axis_size, splitting_last_axis,
                                                   params.num_sections, params.indices);
      MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
        CHECK_LE(sizeof(DType), sizeof(LType));
        const size_t entries_per_load = sizeof(LType) / sizeof(DType);
        // LType words per row handled by this launch: the whole last axis, or only this
        // chunk's part of it when splitting along the last axis.
        size_t last_axis_elements = entries_per_load > 0 ?
                                    (last_axis_size / entries_per_load) : 0;
        if (splitting_last_axis) {
          last_axis_elements = entries_per_load > 0 ?
              ((params.indices[params.num_sections] - params.indices[0]) /
               entries_per_load) : 0;
        }
        // Smallest multiple of 32 threads covering a row, capped at 256.
        size_t block_size = 32;
        const size_t max_threads_block = 256;
        while (block_size < last_axis_elements && block_size < max_threads_block) {
          block_size += 32;
        }
        const size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
        size_t n_blocks = blocks_last_axis;
        for (int i = 0; i < input_data.ndim() - 1; ++i) {
          if (i == real_axis) {
            // Only this chunk's range of the split axis is covered by this launch.
            n_blocks *= (params.indices[params.num_sections] - params.indices[0]);
          } else {
            n_blocks *= input_data.shape_[i];
          }
        }
        // A zero-size input or a chunk made only of empty sections has nothing to copy,
        // and launching an empty grid is an invalid configuration.
        if (n_blocks > 0) {
          if (splitting_last_axis) {
            split_tensor_kernel<true, LType><<<n_blocks, block_size, 0, s->stream_>>>
                (input_data.Size(), params, split_axis_size, tail_size,
                 last_axis_size, blocks_last_axis);
          } else {
            split_tensor_kernel<false, LType><<<n_blocks, block_size, 0, s->stream_>>>
                (input_data.Size(), params, split_axis_size, tail_size,
                 last_axis_size, blocks_last_axis);
          }
          MSHADOW_CUDA_POST_KERNEL_CHECK(split_tensor_kernel);
        }
      });
      sections_processed += params.num_sections;
    });
  }
}
NNVM_REGISTER_OP(Reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
@@ -219,7 +435,7 @@ NNVM_REGISTER_OP(space_to_depth)
.set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);
NNVM_REGISTER_OP(_split_v2)
-.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);
NNVM_REGISTER_OP(_split_v2_backward)
.set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
diff --git a/tests/python/gpu/test_operator_gpu.py
b/tests/python/gpu/test_operator_gpu.py
index a2c27a9..c9ab4c2 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2319,3 +2319,21 @@ def test_fp16_spmm():
out = mxsps.dot(inp, weight)
out_np = mx.nd.dot(inp, weight)
assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
+
@with_seed()
@pytest.mark.serial
@pytest.mark.parametrize('dtype', ["float16", "float32", "float64"])
def test_split_v2_fwd(dtype):
    """Compare GPU split_v2 forward results with numpy.split.

    The GPU operator only uses its custom kernel when the last axis has at least
    128 elements. The cases below therefore cover:
      - short and long last axes;
      - rows that need several thread blocks;
      - more than 128 sections, which needs several kernel launches.
    Shapes are kept small so the test stays fast.
    """
    def check_split(shape, axis, indices):
        mx_data = rand_ndarray(shape, dtype=dtype)
        np_out = np.split(mx_data.asnumpy(), indices_or_sections=indices, axis=axis)
        data = mx.sym.Variable("data")
        sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
        check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, atol=1e-5)

    cases = [
        # Long last axis whose rows need several blocks; split along a middle axis.
        ((2, 3, 2400), 1, (1,)),
        # 258 sections along the last axis, processed in three chunks. The middle chunk
        # starts at an odd offset while all of its own sections have even length.
        ((3, 1024), -1, (1,) + tuple(range(3, 514, 2))),
    ]
    # Random shapes with a short (generic path) and a long (custom kernel) last axis.
    for last_axis_size in (random.randint(1, 10), random.randint(128, 1100)):
        dim = random.randint(2, 5)
        shape = rand_shape_nd(dim - 1, dim=4) + (last_axis_size,)
        axis = random.randint(-dim, dim - 1)
        axis_size = shape[axis]
        samples = random.randint(0, min(axis_size - 1, 200))
        indices = tuple(sorted(random.sample(range(1, axis_size), samples)))
        cases.append((shape, axis, indices))

    for shape, axis, indices in cases:
        check_split(shape, axis, indices)