wenyangchu closed pull request #11247: [WIP] add seed_aug parameter for
ImageRecordItr to fix random seed for augm…
URL: https://github.com/apache/incubator-mxnet/pull/11247
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/src/io/image_aug_default.cc b/src/io/image_aug_default.cc
index 22af7d92750..fb9c1030df8 100644
--- a/src/io/image_aug_default.cc
+++ b/src/io/image_aug_default.cc
@@ -80,6 +80,9 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
int pad;
/*! \brief shape of the image data*/
TShape data_shape;
+ /*! \brief random seed for augmentations */
+ int seed_aug;
+
// declare parameters
DMLC_DECLARE_PARAMETER(DefaultImageAugmentParam) {
DMLC_DECLARE_FIELD(resize).set_default(-1)
@@ -136,6 +139,8 @@ struct DefaultImageAugmentParam : public dmlc::Parameter<DefaultImageAugmentPara
DMLC_DECLARE_FIELD(pad).set_default(0)
.describe("Change size from ``[width, height]`` into "
"``[pad + width + pad, pad + height + pad]`` by padding pixes");
+ DMLC_DECLARE_FIELD(seed_aug).set_default(-1)
+ .describe("Random seed for augmentations. Default -1 does not set random seed.");
}
};
@@ -156,6 +161,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
// contructor
DefaultImageAugmenter() {
rotateM_ = cv::Mat(2, 3, CV_32F);
+ seed_init_state = false;
}
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
std::vector<std::pair<std::string, std::string> > kwargs_left;
@@ -196,6 +202,10 @@ class DefaultImageAugmenter : public ImageAugmenter {
}
cv::Mat Process(const cv::Mat &src, std::vector<float> *label,
common::RANDOM_ENGINE *prnd) override {
+ if (!seed_init_state && param_.seed_aug > -1) {
+ prnd->seed(param_.seed_aug);
+ seed_init_state = true;
+ }
using mshadow::index_t;
cv::Mat res;
if (param_.resize != -1) {
@@ -345,6 +355,7 @@ class DefaultImageAugmenter : public ImageAugmenter {
DefaultImageAugmentParam param_;
/*! \brief list of possible rotate angle */
std::vector<int> rotate_list_;
+ bool seed_init_state;
};
ImageAugmenter* ImageAugmenter::Create(const std::string& name) {
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services