zheng-da commented on a change in pull request #13687: split_v2 operator
URL: https://github.com/apache/incubator-mxnet/pull/13687#discussion_r244906162
##########
File path: src/operator/tensor/matrix_op-inl.h
##########
@@ -2520,6 +2520,323 @@ void SpaceToDepthOpForward(const nnvm::NodeAttrs& attrs,
});
}
+namespace split_enum {
+enum SplitOpInputs {kData};
+} // namespace split_enum
+
+struct SplitParam : public dmlc::Parameter<SplitParam> {
+ TShape indices;
+ int axis;
+ bool squeeze_axis;
+ int sections;
+ DMLC_DECLARE_PARAMETER(SplitParam) {
+ DMLC_DECLARE_FIELD(indices)
+ .describe("Indices of splits. The elements should denote the boundaries at which the split"
+ " is performed along the `axis`.");
+ DMLC_DECLARE_FIELD(axis).set_default(1)
+ .describe("Axis along which to split.");
+ DMLC_DECLARE_FIELD(squeeze_axis).set_default(0)
+ .describe("If true, removes the axis with length 1 from the shapes of the output arrays."
+ " **Note** that setting `squeeze_axis` to ``true`` removes the axis with length 1"
+ " only along the `axis` on which the split is performed."
+ " Also, `squeeze_axis` can be set to ``true``"
+ " only if ``input.shape[axis] == num_outputs``.");
+ DMLC_DECLARE_FIELD(sections).set_default(0)
+ .describe("Number of sections if split equally. Defaults to 0, which means split by indices.");
+ }
+}; // struct SplitParam
+
+inline TShape GetSplitIndices(const TShape& ishape, int axis, int sections) {
+ TShape indices(sections+1);
+ indices[0] = 0;
+ int64_t section_size = ishape[axis] / sections;
+ for (int i = 0; i < sections; ++i) {
+ indices[i+1] = section_size * (i + 1);
+ }
+ return indices;
+}
+
+inline bool SplitOpType(const nnvm::NodeAttrs& attrs,
+ std::vector<int>* in_attrs,
+ std::vector<int>* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 1U);
+ int dtype = (*in_attrs)[0];
+ CHECK_NE(dtype, -1) << "First input must have specified type";
+ const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+ out_attrs->clear();
+ int num_outputs = (param.sections > 0) ? param.sections : param.indices.ndim();
+ for (int i = 0; i < num_outputs; ++i) {
+ out_attrs->push_back(dtype);
+ }
+ return true;
+}
+
+inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
+ std::vector<TShape>* in_attrs,
+ std::vector<TShape>* out_attrs) {
+ using namespace mshadow;
+ const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+ CHECK_EQ(in_attrs->size(), 1U);
+ TShape dshape = in_attrs->at(split_enum::kData);
+ TShape ishape = in_attrs->at(split_enum::kData);
+ if (dshape.ndim() == 0) return false;
+ if (param.axis >= 0) {
+ CHECK_LT(static_cast<size_t>(param.axis), dshape.ndim());
+ } else {
+ CHECK_LT(param.axis + dshape.ndim(), dshape.ndim());
+ }
+ int real_axis = param.axis;
+ if (real_axis < 0) {
+ real_axis += dshape.ndim();
+ }
+ const TShape indices =
+     (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
+ int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim();
+ // Pre-compute squeezed output shape for future usage
+ TShape squeezed_dshape = dshape;
+ for (int d = real_axis; d < static_cast<int>(squeezed_dshape.ndim()) - 1; ++d) {
+ squeezed_dshape[d] = squeezed_dshape[d+1];
+ }
+ squeezed_dshape = TShape(&squeezed_dshape[0], &squeezed_dshape[squeezed_dshape.ndim()-1]);
+ // Assign shape to every output
+ for (int i = 0; i < num_outputs; ++i) {
+ int start = indices[i];
+ int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
+ CHECK(start < end)
+     << "start " << start << " is not less than end " << end << " for subarray " << i;
+ CHECK(end <= ishape[real_axis])
+     << "end " << end << " exceeds the size of the axis " << ishape[real_axis];
+ dshape[real_axis] = (end - start);
+ if (param.squeeze_axis) {
+ CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start;
+ SHAPE_ASSIGN_CHECK(*out_attrs, i, squeezed_dshape);
+ } else {
+ SHAPE_ASSIGN_CHECK(*out_attrs, i, dshape);
+ }
+ }
+ TShape back_calculate_dshape = ishape;
+ back_calculate_dshape[real_axis] = 0;
+ for (int d = 0; d < real_axis; ++d) {
+ back_calculate_dshape[d] = (*out_attrs)[0][d];
+ }
+ if (param.squeeze_axis) {
+ back_calculate_dshape[real_axis] = num_outputs;
+ } else {
+ for (int i = 0; i < num_outputs; ++i) {
+ back_calculate_dshape[real_axis] += (*out_attrs)[i][real_axis];
+ }
+ }
+ for (int d = real_axis + 1; d < static_cast<int>(ishape.ndim()); ++d) {
+ if (param.squeeze_axis) {
+ back_calculate_dshape[d] = (*out_attrs)[0][d - 1];
+ } else {
+ back_calculate_dshape[d] = (*out_attrs)[0][d];
+ }
+ }
+ SHAPE_ASSIGN_CHECK(*in_attrs, split_enum::kData, back_calculate_dshape);
+ return true;
+}
+
+struct SplitKernel {
+ /*!
+ * \brief Map function for split operator indices option
+ * \param i global thread id
+ * \param in_data ptr to input buffer
+ * \param out_data ptr to ptr of outputs buffer
+ * \param indices ptr to indices buffer
+ * \param num_sections # of sections after split
+ * \param axis_size size of the axis to be split on
+ * \param trailing_size step size within the data buffer of the axis to be split on
+ */
+ template<typename DType>
+ static MSHADOW_XINLINE void Map(size_t i,
+                                 const DType *in_data, DType** out_data, const size_t* indices,
+                                 const size_t num_sections, const size_t axis_size,
+                                 const size_t trailing_size) {
+ size_t idx = i / trailing_size % axis_size;
+ size_t target = 0;
+ for (size_t section = 0; section < num_sections; target = section++) {
+ if (indices[section] > idx) {
+ break;
+ }
+ }
+ DType* target_data = out_data[target];
+ const size_t mid_idx = idx - indices[target];
+ const size_t head_idx = i / (trailing_size * axis_size);
+ const size_t tail_idx = i % trailing_size;
+ const size_t section_size = indices[target + 1] - indices[target];
+ const size_t target_idx =
+     head_idx * trailing_size * section_size + mid_idx * trailing_size + tail_idx;
+ target_data[target_idx] = in_data[i];
Review comment:
I'm a little concerned about this kernel. It takes a lot of computation just to copy a single element from the original array to a destination array. At the very least we should copy an entire row at a time unless we split along the last dimension.
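
For illustration, here is a rough sketch of what a row-wise copy could look like. The `SplitRowKernel` name is hypothetical and not part of this PR; the sketch assumes the same `out_data`/`indices`/`trailing_size` layout as the kernel above, with `indices` holding `num_sections + 1` boundaries, and would be launched with one thread per row (input size divided by `trailing_size`) rather than one thread per element.

```cpp
// Hypothetical row-wise variant (sketch only, not the PR's implementation).
// Each thread copies one contiguous run of `trailing_size` elements, so the
// section lookup and index arithmetic run once per row instead of once per element.
struct SplitRowKernel {
  template<typename DType>
  static MSHADOW_XINLINE void Map(size_t i,
                                  const DType* in_data, DType** out_data, const size_t* indices,
                                  const size_t num_sections, const size_t axis_size,
                                  const size_t trailing_size) {
    // i enumerates (leading index, position along the split axis) pairs.
    const size_t head_idx = i / axis_size;
    const size_t idx = i % axis_size;
    // Find the output section that owns this position along the split axis.
    size_t target = 0;
    for (size_t section = 0; section < num_sections; target = section++) {
      if (indices[section] > idx) {
        break;
      }
    }
    const size_t mid_idx = idx - indices[target];
    const size_t section_size = indices[target + 1] - indices[target];
    // Source row starts at element i * trailing_size of the input buffer.
    const DType* src = in_data + i * trailing_size;
    // Destination row inside the target output, shaped [head, section_size, trailing].
    DType* dst = out_data[target] + (head_idx * section_size + mid_idx) * trailing_size;
    // Contiguous copy of the whole row.
    for (size_t j = 0; j < trailing_size; ++j) {
      dst[j] = src[j];
    }
  }
};
```

With a row-wise mapping like this, the search over sections and the index math are amortized over `trailing_size` contiguous elements, which is essentially what the suggestion above amounts to whenever the split axis is not the last dimension.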