tkonolige commented on a change in pull request #8817:
URL: https://github.com/apache/tvm/pull/8817#discussion_r694383996



##########
File path: src/tir/schedule/primitive.h
##########
@@ -19,12 +19,26 @@
 #ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_
 #define TVM_TIR_SCHEDULE_PRIMITIVE_H_
 
+#include <tvm/support/random_engine.h>
 #include <tvm/tir/schedule/state.h>
 
 namespace tvm {
 namespace tir {
 
 /******** Schedule: Sampling ********/
+/*!
+ * \brief Sample one category from candidates according to the probability weights.
+ * \param self The schedule to update
+ * \param rand_state The pointer to schedule's random state
+ * \param candidates The candidates
+ * \param probs The probability distribution of the candidates
+ * \param decision The sampling decision, if any
+ * \return The random variable sampled from candidates
+ */
+TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+                                  const Array<Integer>& candidates, const Array<FloatImm>& probs,
+                                  Optional<Integer>* decision);

Review comment:
       I'm a little confused as to why this `SampleCategorical` modifies `decision` while all the others do not. Could you explain?

##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -208,6 +211,25 @@ Schedule ConcreteScheduleNode::Copy() const {
   }
 
 /******** Schedule: Schedule: Sampling ********/
+
+void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
+  support::LinearCongruentialEngine(&rand_state_).Seed(seed == -1 ? std::random_device()() : seed);
+}
+support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
+  // In order for reproducibility, we compute the new seed using RNG's random state and a different
+  // set of parameters. Note that both 32767 and 1999999973 are prime numbers.
+  return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973;
+}

Review comment:
       It seems like `ForkSeed` is analogous to what is called "splitting" in the random number generator literature. I'm not an expert on this, but I did some research into PRNGs for our Threefry implementation. Everything I read says there is no proof that splitting LCGs is valid (is the method you use here from a paper?). The paper ["Splittable Pseudorandom Number Generators using Cryptographic Hashing"](https://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf) provides some good explanations.
   
   In practice, I expect we will see some issues. Even if this function perfectly bisected the space of random numbers generated by this PRNG (roughly 2^31 states), we could expect to start seeing repeats of previous random numbers after 31 splits. Since this splitting does not perfectly bisect the space, I'd assume repeats start much sooner. Repeating portions of the search space may mean that we are not able to visit the entire search space during tuning, or that we bias results toward a certain section of the space.
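To make the overlap concern concrete, here is a standalone sketch of the `ForkSeed` arithmetic. It assumes the engine behaves like `std::minstd_rand` (x <- 48271 * x mod (2^31 - 1)); that is an assumption for illustration, not a claim about TVM's `LinearCongruentialEngine`, and `ForkSeedSketch` / `ForkEngineSketch` are hypothetical names. Forking is deterministic given the parent state, which gives reproducibility, but a child engine seeded this way runs the same recurrence as its parent, so its stream is the parent's single cycle entered at a different point. (Separately, 32767 = 7 * 31 * 151, so it is not actually prime, despite the code comment.)

```cpp
#include <cstdint>
#include <random>

// Hypothetical stand-in for ForkSeed, assuming a minstd-style engine
// (x <- 48271 * x mod (2^31 - 1)), as implemented by std::minstd_rand.
std::int64_t ForkSeedSketch(std::minstd_rand& parent) {
  // One draw from the parent; this also advances the parent's state.
  const std::int64_t draw = static_cast<std::int64_t>(parent());  // in [1, 2^31 - 2]
  // Same expression as in the diff. draw * 32767 < 2^46, so no int64 overflow.
  return (draw * 32767) % 1999999973;
}

// A child engine seeded from the fork. It runs the same recurrence as the
// parent, so its outputs are a window into the parent's single cycle of
// 2^31 - 2 states, just starting at a different offset.
std::minstd_rand ForkEngineSketch(std::minstd_rand& parent) {
  return std::minstd_rand(static_cast<std::minstd_rand::result_type>(ForkSeedSketch(parent)));
}
```

Nothing in this construction keeps the child's window from running into the parent's (or a sibling's) window, which is the repeat problem described above.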
   
   I'd suggest we adopt a splittable PRNG here, as that appears to be what we need. Maybe we can find an existing implementation online, since implementing your own PRNG can have subtle issues.
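One well-known candidate along these lines is SplitMix64, the generator family behind Java's `SplittableRandom`. The sketch below is a minimal version that derives child seeds from the parent's mixed output; it omits the per-stream "gamma" selection of the full splittable design, it is not TVM code, and the `SplitMix64` struct and its `Fork` method are hypothetical names used for illustration only.

```cpp
#include <cstdint>

// Sketch of SplitMix64: the state advances by a fixed odd "golden gamma",
// and each output is the state passed through a 64-bit finalizer (mixer).
struct SplitMix64 {
  std::uint64_t state;

  explicit SplitMix64(std::uint64_t seed) : state(seed) {}

  std::uint64_t Next() {
    std::uint64_t z = (state += 0x9E3779B97F4A7C15ULL);
    z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
    z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
    return z ^ (z >> 31);
  }

  // Simplified split: seed a child with the parent's next (mixed) output.
  // Parent and child still walk the same 2^64-long state sequence, but the
  // pseudo-random starting offset makes overlap within realistic run lengths
  // far less likely than on a 2^31-state LCG.
  SplitMix64 Fork() { return SplitMix64(Next()); }
};
```

The design point is the larger state plus a strong output mixer; an off-the-shelf, reviewed implementation would still be preferable to hand-rolling this.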

##########
File path: src/tir/schedule/primitive/sampling.cc
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <random>
+
+#include "../../../support/array.h"
+#include "../primitive.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+                          const Array<Integer>& candidates, const Array<FloatImm>& probs,
+                          Optional<Integer>* decision) {
+  CHECK(candidates.size() == probs.size())
+      << "ValueError: number of candidates does not match number of probabilities.";
+  int i = -1;
+  int n = candidates.size();
+
+  if (decision->defined()) {
+    const auto* int_imm = decision->as<IntImmNode>();
+    i = int_imm->value;
+    CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n
+                           << ", but decision is: " << i;
+  } else {
+    std::vector<double> weights = support::AsVector<FloatImm, double>(probs);
+    std::discrete_distribution<int> dist(weights.begin(), weights.end());
+    support::LinearCongruentialEngine rand_(rand_state);
+    i = dist(rand_);
+    ICHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n
+                            << ", but decision is: " << i;

Review comment:
       You can move this check out of the `if`/`else` so that a single check covers both the replayed and the sampled decision.
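A standalone sketch of that restructuring: `i` is resolved in either branch and validated once afterwards. It swaps TVM's types for standard-library stand-ins (`std::vector`, `std::optional`, and `std::mt19937_64` in place of `Array`, `Optional`, and `LinearCongruentialEngine`), and `SampleCategoricalSketch` is a hypothetical name. Writing the sampled index back into `decision` mirrors what the pointer parameter in the declaration implies; it is an assumption about intent, not the PR's code.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <random>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for SampleCategorical using standard-library types.
// If *decision holds a value, it is replayed; otherwise an index is sampled
// from `probs` and stored into *decision. Either way the index is validated
// once, after the branch, and the chosen candidate is returned.
std::int64_t SampleCategoricalSketch(std::mt19937_64* rng,
                                     const std::vector<std::int64_t>& candidates,
                                     const std::vector<double>& probs,
                                     std::optional<std::int64_t>* decision) {
  if (candidates.size() != probs.size()) {
    throw std::invalid_argument("number of candidates does not match number of probabilities");
  }
  const std::int64_t n = static_cast<std::int64_t>(candidates.size());
  std::int64_t i = -1;
  if (decision->has_value()) {
    i = **decision;  // replay a previously recorded decision
  } else {
    std::discrete_distribution<std::int64_t> dist(probs.begin(), probs.end());
    i = dist(*rng);  // draw an index with probability proportional to its weight
  }
  // Single validity check shared by both branches.
  if (i < 0 || i >= n) {
    std::ostringstream msg;
    msg << "Wrong decision value, where n = " << n << ", but decision is: " << i;
    throw std::out_of_range(msg.str());
  }
  *decision = i;  // record the decision so the choice can be replayed later
  return candidates[static_cast<std::size_t>(i)];
}
```

One trade-off: the original uses `CHECK` for a bad user-supplied decision and `ICHECK` for the post-sampling invariant, so merging them means choosing one; the sketch treats an out-of-range index as a user-facing error.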

##########
File path: src/tir/schedule/traced_schedule.h
##########
@@ -47,6 +47,15 @@ class TracedScheduleNode : public ConcreteScheduleNode {
 
  public:
   /******** Schedule: Sampling ********/
+  /*!
+   * \brief Sample an integer given the probability distribution
+   * \param candidates The candidates
+   * \param probs The probability distribution of the candidates
+   * \param decision The sampling decision

Review comment:
       Can you add documentation for `decision` describing what happens when it is set and when it is not?




-- 
This is an automated message from the Apache Git Service.