zhreshold commented on a change in pull request #17841:
URL: https://github.com/apache/incubator-mxnet/pull/17841#discussion_r416867086



##########
File path: src/io/iter_sampler.cc
##########
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file iter_sampler.cc
+ * \brief The sampler iterator for accessing dataset elements.
+ */
+#include <dmlc/parameter.h>
+#include <mshadow/random.h>
+#include <mxnet/io.h>
+#include <mxnet/base.h>
+#include <mxnet/resource.h>
+#include <numeric>
+#include "../common/utils.h"
+#include "./iter_batchloader.h"
+#include "./iter_prefetcher.h"
+
+namespace mxnet {
+namespace io {
+struct SequentialSamplerParam : public dmlc::Parameter<SequentialSamplerParam> 
{
+  /*! \brief Length of the sequence. */
+  size_t length;
+  /*! \brief Start of the index (declared below with a default of 0). */
+  int start;
+  // declare parameters: `length` is declared without a default, `start` defaults to 0
+  DMLC_DECLARE_PARAMETER(SequentialSamplerParam) {
+      DMLC_DECLARE_FIELD(length)
+          .describe("Length of the sequence.");
+      DMLC_DECLARE_FIELD(start).set_default(0)
+          .describe("Start of the index.");
+  }
+};  // struct SequentialSamplerParam
+
+DMLC_REGISTER_PARAMETER(SequentialSamplerParam);
+
+/*!
+ * \brief Iterator that yields the indices of a sequence one at a time, in order.
+ *
+ * Init() fills an internal buffer with `length` consecutive integers beginning
+ * at `start`. Each successful Next() points out_.data[0] at the current index.
+ * out_.data[1] is a placeholder label that always points at the first index.
+ */
+class SequentialSampler : public IIterator<DataInst> {
+ public:
+  /*!
+   * \brief Parse the parameters, build the index buffer and rewind the iterator.
+   * \param kwargs key/value pairs forwarded to SequentialSamplerParam::InitAllowUnknown
+   */
+  virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
+    param_.InitAllowUnknown(kwargs);
+    indices_.resize(param_.length);
+    // fill like arange, honoring the declared `start` offset (default 0)
+    std::iota(std::begin(indices_), std::end(indices_), static_cast<int64_t>(param_.start));
+    out_.data.resize(2);  // label required by DataBatch, we can use fake label here
+    out_.data[1] = TBlob(indices_.data(), TShape({1, }), cpu::kDevMask, 0);
+    // rewind here, as RandomSampler::Init does, so Next() never reads an unset pos_
+    BeforeFirst();
+  }
+
+  /*! \brief Rewind to the first index. */
+  virtual void BeforeFirst(void) {
+    pos_ = 0;
+  }
+
+  /*! \brief Number of indices this sampler produces per pass. */
+  virtual int64_t GetLenHint(void) const {
+    return static_cast<int64_t>(indices_.size());
+  }
+
+  /*!
+   * \brief Advance to the next index.
+   * \return true if out_ now refers to a valid index, false once the sequence is exhausted
+   */
+  virtual bool Next(void) {
+    if (pos_ < indices_.size()) {
+      // the blob aliases storage owned by indices_; no copy is made
+      int64_t *ptr = indices_.data() + pos_;
+      out_.data[0] = TBlob(ptr, TShape({1, }), cpu::kDevMask, 0);
+      ++pos_;
+      return true;
+    }
+    return false;
+  }
+
+  /*! \brief The instance set up by the most recent successful Next(). */
+  virtual const DataInst &Value(void) const {
+    return out_;
+  }
+
+ private:
+  /*! \brief Stored integer indices */
+  std::vector<int64_t> indices_;
+  /*! \brief current position for iteration */
+  std::size_t pos_ = 0;
+  /*! \brief data for next value */
+  DataInst out_;
+  /*! \brief arguments */
+  SequentialSamplerParam param_;
+};  // class SequentialSampler
+
+MXNET_REGISTER_IO_ITER(SequentialSampler)  // registration macro is defined outside this file
+.describe(R"code(Returns the sequential sampler iterator.
+)code" ADD_FILELINE)
+.add_arguments(SequentialSamplerParam::__FIELDS__())  // sampler's own fields: length, start
+.add_arguments(BatchSamplerParam::__FIELDS__())  // plus the fields of the BatchSampler wrapper
+.set_body([]() {
+    return  // the factory hands back the sequential sampler wrapped in a BatchSampler
+        new BatchSampler(
+            new SequentialSampler());  // NOTE(review): raw new; presumably the wrapper owns it - confirm
+  });
+
+struct RandomSamplerParam : public dmlc::Parameter<RandomSamplerParam> {
+  /*! \brief Length of the sequence. */
+  size_t length;
+  /*! \brief Random seed; RandomSampler::Init adds it to kRandMagic to seed its engine. */
+  int seed;
+  // declare parameters: `length` is declared without a default, `seed` defaults to 0
+  DMLC_DECLARE_PARAMETER(RandomSamplerParam) {
+      DMLC_DECLARE_FIELD(length)
+          .describe("Length of the sequence.");
+      DMLC_DECLARE_FIELD(seed).set_default(0)
+          .describe("Random seed.");
+  }
+};  // struct RandomSamplerParam
+
+DMLC_REGISTER_PARAMETER(RandomSamplerParam);
+
+class RandomSampler : public IIterator<DataInst> {
+ public:
+  virtual void Init(const std::vector<std::pair<std::string, std::string> >& 
kwargs) {
+    param_.InitAllowUnknown(kwargs);
+    indices_.resize(param_.length);
+    std::iota(std::begin(indices_), std::end(indices_), 0);  // fill like 
arange
+    rng_.reset(new common::RANDOM_ENGINE(kRandMagic + param_.seed));
+    out_.data.resize(2);  // label required by DataBatch, we can use fake 
label here
+    out_.data[1] = TBlob(indices_.data(), TShape({1, }), cpu::kDevMask, 0);
+    BeforeFirst();
+  }
+
+  virtual void BeforeFirst(void) {
+    std::shuffle(std::begin(indices_), std::end(indices_), *rng_);
+    pos_ = 0;
+  }
+
+  virtual int64_t GetLenHint(void) const {
+    return static_cast<int64_t>(indices_.size());
+  }
+
+  virtual bool Next(void) {
+    if (pos_ < indices_.size()) {
+      int64_t *ptr = indices_.data() + pos_;
+      out_.data[0] = TBlob(ptr, TShape({1, }), cpu::kDevMask, 0);
+      ++pos_;
+      return true;
+    }
+    return false;
+  }
+
+  virtual const DataInst &Value(void) const {
+    return out_;
+  }
+ private:
+  /*! \brief random magic number */
+  static const int kRandMagic = 2333;

Review comment:
       See 
https://github.com/apache/incubator-mxnet/pull/17841#discussion_r416847315
   
   > It's sharing the same random engine with mxnet's initialized seed. Adding a 
separate seed here can help users tweak the randomness of this particular 
sampler. The idea is:
   > 
   > * If you change the seed for mxnet, it affects all internal random generators, 
including this sampler.
   > * If you tweak the seed specified here, only this sampler is affected.
   
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to