This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 16eb89b  Add GPU-optimization for split op (#19131)
16eb89b is described below

commit 16eb89b71d9bdbbb63ce07b7ac80fc029e6b7061
Author: MoisesHer <[email protected]>
AuthorDate: Sun Oct 11 14:33:31 2020 -0700

    Add GPU-optimization for split op (#19131)
    
    * Add GPU-optimization for split op
    
    * Complete operator
    
    * unit-test: use parametrize
    
    * fix lint
    
    * fix lint
    
    * fix lint
---
 src/operator/tensor/matrix_op.cu      | 218 +++++++++++++++++++++++++++++++++-
 tests/python/gpu/test_operator_gpu.py |  18 +++
 2 files changed, 235 insertions(+), 1 deletion(-)

diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index 1707693..d5e17da 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -137,6 +137,222 @@ void SliceDimTwoCsrImpl<gpu>(const mxnet::TShape &begin, 
const mxnet::TShape &en
   });
 }
 
+/*!
+ * \brief Arguments of split_tensor_kernel, handed to the kernel by value.
+ * \tparam DType element type of the tensor being split
+ */
+template <typename DType>
+struct split_tensor_data {
+  // upper bound on the number of sections one instance can describe
+  static constexpr int MaxSections = 128;
+  // number of entries of `outputs` that are in use
+  size_t num_sections;
+  // outputs[i]: destination buffer of section i
+  DType* outputs[MaxSections];
+  // section i spans positions [indices[i], indices[i + 1]) along the split axis
+  size_t indices[MaxSections + 1];
+  // inputs[0]: buffer of the tensor being split
+  DType* inputs[1];
+};
+
+/*!
+ * \brief Copy the sections described by `params` from the input tensor into their outputs.
+ *
+ * Launch layout: 1-D grid, 1-D blocks, no shared memory. Every thread moves one LType
+ * (i.e. sizeof(LType)/sizeof(DType) consecutive elements of the last axis).
+ * `blocks_last_axis` consecutive blocks cover one row of the last axis:
+ *  - split_last_axis == true: blockIdx.x / blocks_last_axis enumerates the rows of the
+ *    input (all axes but the last one flattened); only the range
+ *    [indices[0], indices[num_sections]) of every row is copied.
+ *  - split_last_axis == false: blockIdx.x is decomposed, from slowest to fastest, into
+ *    (leading axes, position in split axis relative to indices[0], tail axes,
+ *    block within the row).
+ *
+ * The offsets are computed with integer divisions by sizeof(LType)/sizeof(DType), so the
+ * caller is expected to pick an LType for which the last axis size (and, when splitting
+ * the last axis, every section boundary) is a multiple of that ratio.
+ *
+ * \tparam split_last_axis whether the split axis is the last axis of the input
+ * \tparam LType type used for the loads and stores
+ * \tparam DType element type of the tensors
+ * \param input_size number of elements of the input (currently unused)
+ * \param params input/output pointers and section boundaries
+ * \param split_axis_size size of the split axis in the input
+ * \param tail_size product of the axes between the split axis and the last axis
+ * \param last_axis_size size of the last axis, in elements
+ * \param blocks_last_axis number of blocks used to cover one row of the last axis
+ */
+template <bool split_last_axis, typename LType, typename DType>
+__global__ void split_tensor_kernel(size_t input_size,
+                                    const split_tensor_data<DType> params,
+                                    size_t split_axis_size,
+                                    size_t tail_size,
+                                    size_t last_axis_size,
+                                    size_t blocks_last_axis) {
+  const int entries_per_load = sizeof(LType)/sizeof(DType);
+  const LType* in_aligned = reinterpret_cast<const LType*>(params.inputs[0]);
+  const size_t last_axis_size_aligned = entries_per_load > 0 ?
+                                        last_axis_size / entries_per_load : last_axis_size;
+  if (split_last_axis) {
+    size_t input_offset_leading = (blockIdx.x / blocks_last_axis) * last_axis_size_aligned;
+    // position (in elements) of this thread inside the row
+    size_t position_last_axis = (blockIdx.x % blocks_last_axis) * blockDim.x * entries_per_load +
+                                 params.indices[0] + threadIdx.x * entries_per_load;
+    if (position_last_axis < params.indices[params.num_sections]) {
+      size_t position_last_axis_aligned = entries_per_load > 0 ?
+                                          position_last_axis / entries_per_load :
+                                          position_last_axis;
+      LType input_data = in_aligned[input_offset_leading + position_last_axis_aligned];
+      // Binary search to find section of each thread:
+      // the last section whose first index is <= position_last_axis
+      size_t lower = 0;
+      size_t upper = params.num_sections - 1;
+      while (lower < upper) {
+        size_t mid = (lower + upper + 1) / 2;
+        if (position_last_axis >= params.indices[mid])
+          lower = mid;
+        else
+          upper = mid - 1;
+      }
+      size_t section = upper;
+      size_t section_size = params.indices[section + 1] - params.indices[section];
+      LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
+      size_t section_size_aligned = entries_per_load > 0 ? section_size / entries_per_load :
+                                                           section_size;
+      size_t index_aligned = entries_per_load > 0 ? params.indices[section] / entries_per_load :
+                                                    params.indices[section];
+      // a row of the output only holds the elements of its own section
+      size_t output_offset_leading = (blockIdx.x / blocks_last_axis) * section_size_aligned;
+      size_t output_position = output_offset_leading + position_last_axis_aligned - index_aligned;
+      out_aligned[output_position] = input_data;
+    }
+  } else {
+    size_t split_axis_size_iter = params.indices[params.num_sections] - params.indices[0];
+    size_t blocks_per_leading_dim = (split_axis_size_iter * tail_size * blocks_last_axis);
+    // input offsets: leading (axes pre-split-axis), at split-axis, tail, and blocks_last_axis
+    size_t input_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
+                                   split_axis_size * tail_size * last_axis_size_aligned;
+    size_t pos_in_split_axis = (blockIdx.x / (tail_size * blocks_last_axis)) %
+                               split_axis_size_iter + params.indices[0];
+    size_t input_offset_split_axis = pos_in_split_axis * tail_size * last_axis_size_aligned;
+    size_t offset_tail = ((blockIdx.x / blocks_last_axis) % tail_size) *
+                         last_axis_size_aligned;
+    // position (in LType units) of this thread inside the row of the last axis
+    size_t position_in_row = (blockIdx.x % blocks_last_axis) * blockDim.x + threadIdx.x;
+    size_t input_offset = input_offset_leading + input_offset_split_axis + offset_tail;
+    // Binary search to find section for this block:
+    // the last section whose first index is <= pos_in_split_axis
+    size_t lower = 0;
+    size_t upper = params.num_sections - 1;
+    while (lower < upper) {
+      size_t mid = (lower + upper + 1) / 2;
+      if (pos_in_split_axis >= params.indices[mid])
+        lower = mid;
+      else
+        upper = mid - 1;
+    }
+    size_t section = upper;
+    size_t section_size = params.indices[section + 1] - params.indices[section];
+    LType* out_aligned = reinterpret_cast<LType*>(params.outputs[section]);
+    // output offsets: leading (axes pre-split-axis), at split-axis, and blocks_last_axis
+    size_t output_offset_leading = (blockIdx.x / blocks_per_leading_dim) *
+                                   section_size * tail_size * last_axis_size_aligned;
+    size_t output_offset_split_axis =
+      ((blockIdx.x % blocks_per_leading_dim) / blocks_last_axis -
+       ((params.indices[section] - params.indices[0]) * tail_size)) * last_axis_size_aligned;
+    size_t output_offset = output_offset_leading + output_offset_split_axis;
+    // The guard has to use the position inside the row, not threadIdx.x alone: when a row
+    // is covered by several blocks, the threads of the last block that fall behind the end
+    // of the row must not copy anything (they would touch the next row, which belongs to a
+    // different section at section boundaries, or run past the end of the buffers).
+    if (position_in_row < last_axis_size_aligned) {
+      LType input_data = in_aligned[input_offset + position_in_row];
+      out_aligned[output_offset + position_in_row] = input_data;
+    }
+  }
+}
+
+/*!
+ * \brief Select the widest load type that can be used to copy the given sections.
+ *
+ * The size in bytes of the last axis must be a multiple of the load width. When the last
+ * axis is the one being split, the same must hold for the first boundary `indices[0]`
+ * and for the size of every section, so that every section boundary is a whole number
+ * of loads.
+ *
+ * \tparam DType element type of the tensors
+ * \param last_axis_size size of the last axis, in elements
+ * \param splitting_last_axis whether the split axis is the last axis
+ * \param n_sections number of sections; `indices` must hold n_sections + 1 entries
+ * \param indices section boundaries along the split axis
+ * \return kFloat64, kFloat32, kFloat16 or kUint8 for loads of 8, 4, 2 or 1 byte(s)
+ */
+template <typename DType>
+int get_load_type_split(size_t last_axis_size,
+                        bool splitting_last_axis,
+                        size_t n_sections,
+                        size_t* indices) {
+  using namespace mshadow;
+  // OR of all the byte counts that must be multiples of the load width: the result is a
+  // multiple of a power of two exactly when every single byte count is
+  size_t bytes_to_align = last_axis_size * sizeof(DType);
+  if (splitting_last_axis) {
+    // the first boundary is not 0 when the sections are processed in several batches
+    bytes_to_align |= indices[0] * sizeof(DType);
+    for (size_t i = 0; i < n_sections; ++i) {
+      size_t size_section = indices[i+1] - indices[i];
+      bytes_to_align |= size_section * sizeof(DType);
+    }
+  }
+  if (bytes_to_align % 8 == 0) {
+    return kFloat64;
+  } else if (bytes_to_align % 4 == 0) {
+    return kFloat32;
+  } else if (bytes_to_align % 2 == 0) {
+    return kFloat16;
+  } else {
+    return kUint8;
+  }
+}
+
+/*!
+ * \brief GPU forward pass of the split operator.
+ *
+ * Inputs whose last axis has fewer than 128 elements are forwarded to
+ * SplitOpForwardImpl<gpu>. Otherwise the sections are copied with split_tensor_kernel,
+ * in batches of at most split_tensor_data::MaxSections sections per launch.
+ *
+ * \param attrs node attributes holding the parsed SplitParam
+ * \param ctx operator context providing the GPU stream
+ * \param inputs the single tensor to split
+ * \param req write requests of the outputs (only forwarded to the fallback path)
+ * \param outputs one tensor per section
+ */
+inline void SplitOpForwardGPU(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
+  const TBlob& input_data = inputs[split_enum::kData];
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += input_data.ndim();
+  }
+  // validate the axis before it is used to index the shape
+  CHECK_GE(real_axis, 0);
+  CHECK_LT(real_axis, input_data.ndim());
+  const mxnet::TShape& ishape = input_data.shape_;
+  size_t last_axis_size = ishape[input_data.ndim() - 1];
+  size_t split_axis_size = ishape[real_axis];
+  size_t tail_size = 1;  // does not include last dim
+  for (int i = real_axis + 1; i < input_data.ndim() - 1; ++i) {
+    tail_size *= ishape[i];
+  }
+  if (last_axis_size < 128) {
+    // custom kernel will not be efficient with less than 128 elements in last axis
+    SplitOpForwardImpl<gpu>(attrs, ctx, inputs, req, outputs, real_axis);
+    return;
+  }
+  Stream<gpu> *s = ctx.get_stream<gpu>();
+  const mxnet::TShape split_pts =
+    (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+  // section i spans [indices[i], indices[i + 1]) along the split axis
+  std::vector<size_t> indices;
+  for (const auto& split_pos : split_pts) {
+    indices.push_back(split_pos);
+  }
+  if (param.sections == 0) {
+    // close the last section with the size of the split axis
+    indices.push_back(ishape[real_axis]);
+  }
+  size_t n_sections = indices.size() - 1;
+  bool splitting_last_axis = (real_axis == input_data.ndim() - 1);
+
+  for (size_t sections_processed = 0; sections_processed < n_sections;) {
+    size_t remaining_sections = n_sections - sections_processed;
+    MSHADOW_TYPE_SWITCH(input_data.type_flag_, DType, {
+      // set parameters
+      split_tensor_data<DType> params{};
+      params.num_sections = std::min<size_t>(remaining_sections, params.MaxSections);
+      params.inputs[0] = input_data.dptr<DType>();
+      for (size_t i = 0; i < params.num_sections; ++i) {
+        params.outputs[i] = outputs[sections_processed + i].dptr<DType>();
+        params.indices[i] = indices[sections_processed + i];
+      }
+      params.indices[params.num_sections] = indices[sections_processed + params.num_sections];
+      // part of the split axis covered by this batch of sections
+      const size_t batch_span = params.indices[params.num_sections] - params.indices[0];
+      // load type: we need to check that last axis size is multiple of ltype
+      // and if splitting_last_axis, all section sizes as well
+      int ltype = get_load_type_split<DType>(last_axis_size, splitting_last_axis,
+                                             params.num_sections, params.indices);
+      MXNET_LOAD_TYPE_SWITCH(ltype, LType, {
+        CHECK_LE(sizeof(DType), sizeof(LType));
+        const size_t entries_per_load = sizeof(LType) / sizeof(DType);
+        size_t block_size = 32;
+        size_t max_threads_block = 256;
+        // number of loads per row of the last axis; when splitting the last axis only the
+        // part covered by this batch is copied (may not be the whole axis if too many sections)
+        size_t last_axis_elements = 0;
+        if (entries_per_load > 0) {
+          last_axis_elements =
+            (splitting_last_axis ? batch_span : last_axis_size) / entries_per_load;
+        }
+        // grow the block in steps of 32 threads until it covers a row or reaches the maximum
+        while (block_size < last_axis_elements && (block_size < max_threads_block)) {
+          block_size += 32;
+        }
+        size_t blocks_last_axis = (last_axis_elements + block_size - 1) / block_size;
+        size_t n_blocks = blocks_last_axis;
+        for (int i = 0 ; i < input_data.ndim() - 1; ++i) {
+          if (i == real_axis) {
+            // may not be possible to include all sections if too many
+            n_blocks *= batch_span;
+          } else {
+            n_blocks *= ishape[i];
+          }
+        }
+        // nothing to copy for an empty tensor or an empty batch of sections;
+        // a launch with an empty grid would be an invalid configuration
+        if (n_blocks > 0) {
+          if (splitting_last_axis) {
+            split_tensor_kernel<true, LType><<<n_blocks, block_size, 0, s->stream_>>>
+              (input_data.Size(), params, split_axis_size, tail_size,
+               last_axis_size, blocks_last_axis);
+          } else {
+            split_tensor_kernel<false, LType><<<n_blocks, block_size, 0, s->stream_>>>
+              (input_data.Size(), params, split_axis_size, tail_size,
+               last_axis_size, blocks_last_axis);
+          }
+        }
+      });
+      sections_processed += params.num_sections;
+    });
+  }
+}
 
 NNVM_REGISTER_OP(Reshape)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
@@ -219,7 +435,7 @@ NNVM_REGISTER_OP(space_to_depth)
 .set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);
 
 NNVM_REGISTER_OP(_split_v2)
-.set_attr<FCompute>("FCompute<gpu>", SplitOpForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);
 
 NNVM_REGISTER_OP(_split_v2_backward)
 .set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index a2c27a9..c9ab4c2 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2319,3 +2319,21 @@ def test_fp16_spmm():
     out = mxsps.dot(inp, weight)
     out_np = mx.nd.dot(inp, weight)
     assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
+
+@with_seed()
[email protected]
[email protected]('dtype', ["float16", "float32", "float64"])
+def test_split_v2_fwd(dtype):
+    dim = random.randint(2, 9)
+    shape = rand_shape_nd(dim)
+    axis = random.randint(-dim, dim-1)
+    axis_size = shape[axis]
+    samples = random.randint(0, axis_size - 1)
+    indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
+    indices = tuple(indices)
+    mx_data = rand_ndarray(shape, dtype=dtype)
+    np_data = mx_data.asnumpy()
+    np_out = np.split(np_data, indices_or_sections=indices, axis=axis)
+    data = mx.sym.Variable("data")
+    sym = mx.sym.split_v2(data, indices_or_sections=indices, axis=axis)
+    check_symbolic_forward(sym, {"data": mx_data}, np_out, rtol=1e-3, 
atol=1e-5)

Reply via email to