This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch vision in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit e97ef363bbae809f39c07415a6c6740b884c76f6 Author: Yizhi Liu <javeli...@gmail.com> AuthorDate: Sun Nov 26 17:20:55 2017 -0800 image flip op (#8759) * image flip op * rm image_common.h * fix * lint code * flip optimize --- src/operator/image/image_random-inl.h | 66 +++++++++++++++++++++++++++++++++-- src/operator/image/image_random.cc | 16 +++++++++ 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index ebbf60a..5c552b2 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -28,14 +28,20 @@ #include <mxnet/base.h> #include <algorithm> #include <vector> -#include <opencv2/opencv.hpp> -#include <opencv2/core/mat.hpp> +#include <algorithm> +#include <utility> #include "../mxnet_op.h" #include "../operator_common.h" namespace mxnet { namespace op { +inline void CheckIsImage(const TBlob &image) {  // CHECK-fails unless image is a 3-D uint8 tensor whose last dim (channels) is 1 or 3 + CHECK_EQ(image.type_flag_, mshadow::kUint8) << "input type is not an image."; + CHECK_EQ(image.ndim(), 3) << "input dimension is not 3."; + CHECK(image.shape_[2] == 1 || image.shape_[2] == 3) << "image channel should be 1 or 3."; +} + static void RandomFlip(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector<TBlob> &inputs, @@ -76,6 +82,7 @@ static void ToTensor(const nnvm::NodeAttrs &attrs, const std::vector<TBlob> &outputs) { CHECK_EQ(req[0], kWriteTo) << "`to_tensor` does not support inplace"; + CheckIsImage(inputs[0]); int length = inputs[0].shape_[0] * inputs[0].shape_[1]; int channel = inputs[0].shape_[2]; @@ -101,7 +108,6 @@ struct NormalizeParam : public dmlc::Parameter<NormalizeParam> { } }; - inline bool NormalizeShape(const nnvm::NodeAttrs& attrs, std::vector<TShape> *in_attrs, std::vector<TShape> *out_attrs) { @@ -145,6 +151,60 @@ static void Normalize(const nnvm::NodeAttrs &attrs, }); } +struct FlipParam : public dmlc::Parameter<FlipParam> { + int axis; + DMLC_DECLARE_PARAMETER(FlipParam) { + 
DMLC_DECLARE_FIELD(axis) + .describe("0 or 1. 0 for horizontal flip, 1 for vertical flip."); + } +}; + +#define SWAP_IF_INPLACE(dst, dst_idx, src, src_idx) \ + if (dst == src) { \ + std::swap(dst[dst_idx], src[src_idx]); \ + } else { \ + dst[dst_idx] = src[src_idx]; \ + } + +template<typename DType> +static void FlipImpl(const TShape &shape, DType *src, DType *dst, int axis) { + const int height = shape[0]; + const int width = shape[1]; + const int nchannel = shape[2]; + + const int length = width * nchannel; + const int height_stride = (src == dst && axis == 1) ? (height >> 1) : height; + const int width_stride = (src == dst && axis == 0) ? (width >> 1) : width; + + for (int h = 0; h < height_stride; ++h) { + const int h_dst = (axis == 0) ? h : (height - 1 - h);  // mirrored row; last valid row is height - 1 + for (int w = 0; w < width_stride; ++w) { + const int w_dst = (axis == 0) ? (width - 1 - w) : w;  // mirrored column; last valid column is width - 1 + const int idx_dst = h_dst * length + w_dst * nchannel; + const int idx_src = h * length + w * nchannel; + SWAP_IF_INPLACE(dst, idx_dst, src, idx_src); + if (nchannel > 1) { + SWAP_IF_INPLACE(dst, idx_dst + 1, src, idx_src + 1); + SWAP_IF_INPLACE(dst, idx_dst + 2, src, idx_src + 2); + } + } + } +} + +static void Flip(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector<TBlob> &inputs, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &outputs) { + const FlipParam &param = nnvm::get<FlipParam>(attrs.parsed); + CHECK(param.axis == 0 || param.axis == 1) << "flip axis must be 0 or 1."; + CheckIsImage(inputs[0]); + const TShape& ishape = inputs[0].shape_; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + FlipImpl(ishape, inputs[0].dptr<DType>(), outputs[0].dptr<DType>(), param.axis); + }); +} + struct RandomBrightnessParam : public dmlc::Parameter<RandomBrightnessParam> { float max_brightness; DMLC_DECLARE_PARAMETER(RandomBrightnessParam) { diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 5b47f50..4184382 100644 ---
a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -59,6 +59,22 @@ NNVM_REGISTER_OP(_image_normalize) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(NormalizeParam::__FIELDS__()); +DMLC_REGISTER_PARAMETER(FlipParam); +NNVM_REGISTER_OP(_image_flip) +.describe(R"code(Flip the input image along the given axis. axis=0 reverses the width dimension (horizontal flip); axis=1 reverses the height dimension (vertical flip). The input must be a uint8 image of shape (height, width, channels) with 1 or 3 channels.)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser<FlipParam>) +.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) +.set_attr<nnvm::FInplaceOption>("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector<std::pair<int, int> >{{0, 0}}; + }) +.set_attr<FCompute>("FCompute<cpu>", Flip) +.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(FlipParam::__FIELDS__()); DMLC_REGISTER_PARAMETER(RandomBrightnessParam); NNVM_REGISTER_OP(_image_random_brightness) -- To stop receiving notification emails like this one, please contact "comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.