This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch vision
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/vision by this push:
     new f33654b  [Image] add random lighting (#8779)
f33654b is described below

commit f33654b13b9cf16da42fe6fe6fe3d1d3e3cfe779
Author: Xingjian Shi <xsh...@ust.hk>
AuthorDate: Wed Nov 22 15:53:37 2017 -0800

    [Image] add random lighting (#8779)
    
    * add random lighting
    
    * fix
---
 python/mxnet/gluon/data/vision/transforms.py    | 19 +++++
 src/operator/image/image_random-inl.h           | 95 +++++++++++++++++++++++++
 src/operator/image/image_random.cc              | 43 +++++++++--
 tests/python/unittest/test_gluon_data_vision.py | 40 +++++++++++
 4 files changed, 191 insertions(+), 6 deletions(-)

diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py
index e1deef6..931d644 100644
--- a/python/mxnet/gluon/data/vision/transforms.py
+++ b/python/mxnet/gluon/data/vision/transforms.py
@@ -21,6 +21,7 @@ from .. import dataset
 from ...block import Block, HybridBlock
 from ...nn import Sequential, HybridSequential
 from .... import ndarray, initializer
+from ....base import _Null
 
 
 class Compose(Sequential):
@@ -151,3 +152,58 @@ class RandomColorJitter(HybridBlock):
 
     def hybrid_forward(self, F, x):
         return F.image.random_color_jitter(x, *self._args)
+
+
+class AdjustLighting(HybridBlock):
+    """Adjust the lighting of an image with an AlexNet-style PCA offset.
+
+    Every pixel is shifted by ``E . (alpha_rgb * eigval)``, where ``E`` is
+    ``eigvec`` read as a row-major 3x3 matrix whose columns are the
+    eigenvectors; the result is clipped to [0, 255]. Arguments left unset
+    fall back to the operator defaults.
+
+    Parameters
+    ----------
+    alpha_rgb : tuple of 3 floats
+        Coefficient applied to each principal component.
+    eigval : tuple of 3 floats
+        Eigenvalues of the RGB covariance.
+    eigvec : tuple of 9 floats
+        Eigenvectors of the RGB covariance, laid out as described above.
+
+    Inputs:
+        - **x**: uint8 image of shape (H, W, 3).
+    """
+    def __init__(self, alpha_rgb=_Null, eigval=_Null, eigvec=_Null):
+        super(AdjustLighting, self).__init__()
+        self._args = (alpha_rgb, eigval, eigvec)
+
+    def hybrid_forward(self, F, x):
+        return F.image.adjust_lighting(x, *self._args)
+
+
+class RandomLighting(HybridBlock):
+    """Randomly adjust the lighting of an image with AlexNet-style PCA noise.
+
+    On every forward call three coefficients are drawn independently from
+    ``N(0, alpha_std**2)`` and applied as in :class:`AdjustLighting`.
+    Arguments left unset fall back to the operator defaults.
+
+    Parameters
+    ----------
+    alpha_std : float
+        Standard deviation of the lighting noise.
+    eigval : tuple of 3 floats
+        Eigenvalues of the RGB covariance.
+    eigvec : tuple of 9 floats
+        Row-major 3x3 matrix whose columns are the eigenvectors.
+
+    Inputs:
+        - **x**: uint8 image of shape (H, W, 3).
+    """
+    def __init__(self, alpha_std=_Null, eigval=_Null, eigvec=_Null):
+        super(RandomLighting, self).__init__()
+        self._args = (alpha_std, eigval, eigvec)
+
+    def hybrid_forward(self, F, x):
+        return F.image.random_lighting(x, *self._args)
diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index f823c8c..ebbf60a 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -26,6 +26,8 @@
 #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 
 #include <mxnet/base.h>
+#include <algorithm>
+#include <random>
 #include <vector>
 #include <opencv2/opencv.hpp>
 #include <opencv2/core/mat.hpp>
@@ -290,11 +292,129 @@ static void RandomColorJitter(const nnvm::NodeAttrs &attrs,
                               const std::vector<TBlob> &outputs) {
 }
 
+// Parameters of _image_adjust_lighting. The eigval/eigvec defaults are the
+// ImageNet RGB PCA statistics used by AlexNet-style lighting augmentation.
+struct AdjustLightingParam : public dmlc::Parameter<AdjustLightingParam> {
+  nnvm::Tuple<float> alpha_rgb;
+  nnvm::Tuple<float> eigval;
+  nnvm::Tuple<float> eigvec;
+  DMLC_DECLARE_PARAMETER(AdjustLightingParam) {
+    DMLC_DECLARE_FIELD(alpha_rgb)
+    .set_default({0, 0, 0})
+    .describe("The lighting alphas, one per principal component (3 numbers).");
+    DMLC_DECLARE_FIELD(eigval)
+    .describe("Eigenvalues of the RGB covariance (3 numbers).")
+    .set_default({ 55.46, 4.794, 1.148 });
+    DMLC_DECLARE_FIELD(eigvec)
+    .describe("Eigenvectors of the RGB covariance as a row-major 3x3 matrix, "
+              "one eigenvector per column (9 numbers).")
+    .set_default({ -0.5675,  0.7192,  0.4009,
+                   -0.5808, -0.0045, -0.8140,
+                   -0.5836, -0.6948,  0.4203 });
+  }
+};
+
+// Parameters of _image_random_lighting. The eigval/eigvec defaults match
+// AdjustLightingParam.
+struct RandomLightingParam : public dmlc::Parameter<RandomLightingParam> {
+  float alpha_std;
+  nnvm::Tuple<float> eigval;
+  nnvm::Tuple<float> eigvec;
+  DMLC_DECLARE_PARAMETER(RandomLightingParam) {
+    DMLC_DECLARE_FIELD(alpha_std)
+    .set_default(0.05)
+    .set_lower_bound(0.0f)
+    .describe("Standard deviation of the zero-mean normal distribution the "
+              "lighting alphas are drawn from.");
+    DMLC_DECLARE_FIELD(eigval)
+    .describe("Eigenvalues of the RGB covariance (3 numbers).")
+    .set_default({ 55.46, 4.794, 1.148 });
+    DMLC_DECLARE_FIELD(eigvec)
+    .describe("Eigenvectors of the RGB covariance as a row-major 3x3 matrix, "
+              "one eigenvector per column (9 numbers).")
+    .set_default({ -0.5675,  0.7192,  0.4009,
+                   -0.5808, -0.0045, -0.8140,
+                   -0.5836, -0.6948,  0.4203 });
+  }
+};
+
+// Adds an AlexNet-style PCA lighting offset to an H x W x 3 interleaved uint8
+// image. The offset for channel c is sum_j eigvec[3 * c + j] * alpha_j * eigval[j],
+// i.e. eigvec is a row-major 3x3 matrix whose columns are the eigenvectors.
+// Results are clamped to [0, 255] and truncated to an integer. Each element is
+// read before it is written, so dst may alias src (the ops allow in-place).
+inline void AdjustLightingImpl(uint8_t* dst, const uint8_t* src,
+                               float alpha_r, float alpha_g, float alpha_b,
+                               const nnvm::Tuple<float>& eigval,
+                               const nnvm::Tuple<float>& eigvec,
+                               int H, int W) {
+  const float a0 = alpha_r * eigval[0];
+  const float a1 = alpha_g * eigval[1];
+  const float a2 = alpha_b * eigval[2];
+  const float offset[3] = {
+    a0 * eigvec[0] + a1 * eigvec[1] + a2 * eigvec[2],
+    a0 * eigvec[3] + a1 * eigvec[4] + a2 * eigvec[5],
+    a0 * eigvec[6] + a1 * eigvec[7] + a2 * eigvec[8],
+  };
+  // Clamp in float before narrowing so a large offset cannot overflow the cast.
+  auto to_u8 = [](float v) {
+    return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, v)));
+  };
+  const size_t num_pixels = static_cast<size_t>(H) * W;
+  for (size_t i = 0; i < num_pixels; ++i) {
+    for (size_t c = 0; c < 3; ++c) {
+      dst[3 * i + c] = to_u8(static_cast<float>(src[3 * i + c]) + offset[c]);
+    }
+  }
+}
+
+// FCompute<cpu> of _image_adjust_lighting: applies the fixed param.alpha_rgb.
+static void AdjustLighting(const nnvm::NodeAttrs &attrs,
+                           const OpContext &ctx,
+                           const std::vector<TBlob> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &outputs) {
+  const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
+  CHECK_EQ(param.alpha_rgb.ndim(), 3) << "There should be 3 numbers in the alpha_rgb.";
+  CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
+  CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
+  CHECK_EQ(inputs[0].ndim(), 3);
+  CHECK_EQ(inputs[0].size(2), 3);
+  const int H = inputs[0].size(0);
+  const int W = inputs[0].size(1);
+  AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
+                     param.alpha_rgb[0], param.alpha_rgb[1], param.alpha_rgb[2],
+                     param.eigval, param.eigvec, H, W);
+}
+
+// FCompute<cpu> of _image_random_lighting: draws one alpha per component from
+// N(0, alpha_std^2) with the kRandom resource, then applies AdjustLightingImpl.
 static void RandomLighting(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
                            const std::vector<OpReqType> &req,
                            const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
+  CHECK_EQ(param.eigval.ndim(), 3) << "There should be 3 numbers in the eigval.";
+  CHECK_EQ(param.eigvec.ndim(), 9) << "There should be 9 numbers in the eigvec.";
+  CHECK_EQ(inputs[0].ndim(), 3);
+  CHECK_EQ(inputs[0].size(2), 3);
+  const int H = inputs[0].size(0);
+  const int W = inputs[0].size(1);
+  float alpha_r = 0.0f, alpha_g = 0.0f, alpha_b = 0.0f;
+  // std::normal_distribution needs stddev > 0; alpha_std == 0 means no noise.
+  if (param.alpha_std > 0.0f) {
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
+    std::normal_distribution<float> dist(0, param.alpha_std);
+    alpha_r = dist(prnd->GetRndEngine());
+    alpha_g = dist(prnd->GetRndEngine());
+    alpha_b = dist(prnd->GetRndEngine());
+  }
+  AdjustLightingImpl(outputs[0].dptr<uint8_t>(), inputs[0].dptr<uint8_t>(),
+                     alpha_r, alpha_g, alpha_b,
+                     param.eigval, param.eigvec, H, W);
 }
 
 
diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc
index 7ff7328..5b47f50 100644
--- a/src/operator/image/image_random.cc
+++ b/src/operator/image/image_random.cc
@@ -35,9 +35,6 @@ NNVM_REGISTER_OP(_image_to_tensor)
 .describe(R"code()code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
-  return std::vector<ResourceRequest>{ResourceRequest::kRandom};
-})
 .set_attr<nnvm::FInferShape>("FInferShape", ToTensorShape)
 .set_attr<nnvm::FInferType>("FInferType", ToTensorType)
 .set_attr<FCompute>("FCompute<cpu>", ToTensor)
@@ -51,9 +48,6 @@ NNVM_REGISTER_OP(_image_normalize)
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormalizeParam>)
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
-  return std::vector<ResourceRequest>{ResourceRequest::kRandom};
-})
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -126,5 +120,46 @@ NNVM_REGISTER_OP(_image_random_saturation)
 .add_argument("data", "NDArray-or-Symbol", "The input.")
 .add_arguments(RandomSaturationParam::__FIELDS__());
 
+// Deterministic AlexNet-style lighting shift. In-place is safe because every
+// output element depends only on the input element at the same position.
+DMLC_REGISTER_PARAMETER(AdjustLightingParam);
+NNVM_REGISTER_OP(_image_adjust_lighting)
+.describe(R"code(Adjust the lighting level of the input. Follow the AlexNet style.)code"
+          ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<AdjustLightingParam>)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", AdjustLighting)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(AdjustLightingParam::__FIELDS__());
+
+// Random variant of the above; requests kRandom to sample the lighting alphas.
+DMLC_REGISTER_PARAMETER(RandomLightingParam);
+NNVM_REGISTER_OP(_image_random_lighting)
+.describe(R"code(Randomly add PCA noise. Follow the AlexNet style.)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<RandomLightingParam>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
+  return std::vector<ResourceRequest>{ResourceRequest::kRandom};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", RandomLighting)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
+.add_argument("data", "NDArray-or-Symbol", "The input.")
+.add_arguments(RandomLightingParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py
new file mode 100644
index 0000000..0c9e5c1
--- /dev/null
+++ b/tests/python/unittest/test_gluon_data_vision.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import print_function
+import mxnet as mx
+import mxnet.ndarray as nd
+import numpy as np
+from mxnet import gluon
+from mxnet.gluon.data.vision.transforms import AdjustLighting, RandomLighting
+from mxnet.test_utils import assert_almost_equal
+
+# ImageNet RGB PCA statistics used for AlexNet-style lighting augmentation;
+# EIGVEC holds one eigenvector per column.
+EIGVAL = np.array([55.46, 4.794, 1.148])
+EIGVEC = np.array([[-0.5675, 0.7192, 0.4009],
+                   [-0.5808, -0.0045, -0.8140],
+                   [-0.5836, -0.6948, 0.4203]])
+
+
+def test_adjust_lighting():
+    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
+    alpha_rgb = [0.05, 0.06, 0.07]
+    f = AdjustLighting(alpha_rgb=alpha_rgb, eigval=EIGVAL.ravel().tolist(),
+                       eigvec=EIGVEC.ravel().tolist())
+    out_nd = f(nd.array(data_in, dtype=np.uint8))
+    # Per-channel offset E . (alpha * eigval), broadcast over every pixel.
+    offset = np.dot(EIGVEC * alpha_rgb, EIGVAL.reshape((3, 1))).reshape((1, 1, 3))
+    out_gt = np.clip(data_in.astype(np.float32) + offset, 0, 255).astype(np.uint8)
+    assert_almost_equal(out_nd.asnumpy(), out_gt)
+
+
+def test_random_lighting():
+    # Keep pixels well away from 0/255 so that clipping never kicks in.
+    data_in = np.random.randint(100, 156, (32, 32, 3)).astype(np.uint8)
+    f = RandomLighting(alpha_std=0.05, eigval=EIGVAL.ravel().tolist(),
+                       eigvec=EIGVEC.ravel().tolist())
+    out = f(nd.array(data_in, dtype=np.uint8)).asnumpy()
+    assert out.shape == data_in.shape
+    assert out.dtype == np.uint8
+    # Every pixel of a channel gets the same offset, up to float truncation.
+    diff = out.astype(np.int32) - data_in.astype(np.int32)
+    for c in range(3):
+        assert diff[:, :, c].max() - diff[:, :, c].min() <= 1
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.

Reply via email to