wenyangchu closed pull request #11247: [WIP] add seed_aug parameter for 
ImageRecordItr to fix random seed for augm…
URL: https://github.com/apache/incubator-mxnet/pull/11247
 
 
   

This pull request was merged from a forked repository. GitHub does not
display the original diff of a foreign (fork) pull request once it has
been merged, so the diff is reproduced below for the sake of provenance:

diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc
index 22af7d92750..fb9c1030df8 100644
--- a/src/io/image_aug_default.cc
+++ b/src/io/image_aug_default.cc
@@ -80,6 +80,9 @@ struct DefaultImageAugmentParam : public 
dmlc::Parameter<DefaultImageAugmentPara
   int pad;
   /*! \brief shape of the image data*/
   TShape data_shape;
+  /*! \brief random seed for augmentations */
+  int seed_aug;
+
   // declare parameters
   DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) {
     DMLC_DECLARE_FIELD(resize).set_default(-1)
@@ -136,6 +139,8 @@ struct DefaultImageAugmentParam : public 
dmlc::Parameter<DefaultImageAugmentPara
     DMLC_DECLARE_FIELD(pad).set_default(0)
         .describe("Change size from ``[width, height]`` into "
                   "``[pad + width + pad, pad + height + pad]`` by padding 
pixes");
+    DMLC_DECLARE_FIELD(seed_aug).set_default(-1)
+        .describe("Random seed for augmentations. Default -1 does not set 
random seed.");
   }
 };
 
@@ -156,6 +161,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
   // contructor
   DefaultImageAugmenter() {
     rotateM_ = cv::Mat(2, 3, CV_32F);
+    seed_init_state = false;
   }
   void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) 
override {
     std::vector<std::pair<std::string, std::string> > kwargs_left;
@@ -196,6 +202,10 @@ class DefaultImageAugmenter : public ImageAugmenter {
   }
   cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
                   common::RANDOM_ENGINE *prnd) override {
+    if (!seed_init_state && param_.seed_aug > -1) {
+      prnd->seed(param_.seed_aug);
+      seed_init_state = true;
+    }
     using mshadow::index_t;
     cv::Mat res;
     if (param_.resize != -1) {
@@ -345,6 +355,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
   DefaultImageAugmentParam param_;
   /*! \brief list of possible rotate angle */
   std::vector<int> rotate_list_;
+  bool seed_init_state;
 };
 
 ImageAugmenter* ImageAugmenter::Create(const std::string& name) {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to