tkonolige commented on a change in pull request #8817:
URL: https://github.com/apache/tvm/pull/8817#discussion_r694383996
##########
File path: src/tir/schedule/primitive.h
##########
@@ -19,12 +19,26 @@
#ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_
#define TVM_TIR_SCHEDULE_PRIMITIVE_H_
+#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
namespace tvm {
namespace tir {
/******** Schedule: Sampling ********/
+/*!
+ * \brief Sample once category from candidates according to the probability weights.
+ * \param self The schedule to update
+ * \param rand_state The pointer to schedule's random state
+ * \param candidates The candidates
+ * \param probs The probability distribution of the candidates
+ * \param decision The sampling decision, if any
+ * \return The random variable sampled from candidates
+ */
+TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+ const Array<Integer>& candidates, const Array<FloatImm>& probs,
+ Optional<Integer>* decision);
Review comment:
I'm a little confused as to why this `SampleCategorical` modifies
`decision` while all the others do not. Could you explain?
##########
File path: src/tir/schedule/concrete_schedule.cc
##########
@@ -208,6 +211,25 @@ Schedule ConcreteScheduleNode::Copy() const {
}
/******** Schedule: Schedule: Sampling ********/
+
+void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
+ support::LinearCongruentialEngine(&rand_state_).Seed(seed == -1 ? std::random_device()() : seed);
+}
+support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
+ // In order for reproducibility, we computer the new seed using RNG's random state and a different
+ // set of parameters. Note that both 32767 and 1999999973 are prime numbers.
+ return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973;
+}
Review comment:
It seems like `ForkSeed` is analogous to what is called "splitting" in the
random number generator literature. I'm not an expert on this, but I did
do a bit of research into PRNGs for the Threefry implementation we have.
Everything I read says that there are no proofs of the validity of splitting
LCGs (is the method you use here from a paper?). The paper
["Splittable Pseudorandom Number Generators using Cryptographic Hashing"](https://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf)
provides some good explanations.
In practice, I expect we will see some issues. If this function somehow
perfectly bisected the space of random numbers generated by this PRNG, we
could expect to start seeing repeats of previous random numbers after 31
splits. Given that this splitting does not perfectly bisect the space, I'd
assume that we start seeing repeats much sooner. Repeating portions of the
search space may mean that we are not able to visit the entire search space
during tuning, or that we bias results towards a certain section of the
space.
I'd suggest we adopt a splittable PRNG here, as that appears to be what we
need. Maybe we can find an existing implementation online, since implementing
your own PRNG can have subtle issues.
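To make the overlap concern concrete, here is a toy sketch in C++. It is not
the engine from this PR: the modulus 31 and multiplier 3 are made-up
parameters, chosen so that the whole state space can be enumerated, and
`Next`, `Cycle`, and `EveryChildOverlapsParent` are hypothetical helper names.
In a full-period multiplicative LCG like this toy one, every nonzero seed lands
somewhere on the same cycle, so a forked stream is the parent's stream read
from a different offset.

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Toy multiplicative LCG: x <- (a * x) mod m, with m = 31 and a = 3.
// Illustrative parameters only (an assumption, not taken from the PR).
// 3 is a primitive root mod 31, so all nonzero states form one cycle of length 30.
constexpr uint32_t kMultiplier = 3;
constexpr uint32_t kModulus = 31;

uint32_t Next(uint32_t state) { return (kMultiplier * state) % kModulus; }

// Collect the states visited from `seed` until the sequence returns to it.
std::vector<uint32_t> Cycle(uint32_t seed) {
  std::vector<uint32_t> states;
  uint32_t x = seed;
  do {
    x = Next(x);
    states.push_back(x);
  } while (x != seed);
  return states;
}

// Returns true if every nonzero child seed yields a stream whose values
// all appear in the parent's stream as well.
bool EveryChildOverlapsParent(uint32_t parent_seed) {
  std::vector<uint32_t> parent = Cycle(parent_seed);
  std::set<uint32_t> parent_values(parent.begin(), parent.end());
  for (uint32_t child_seed = 1; child_seed < kModulus; ++child_seed) {
    for (uint32_t value : Cycle(child_seed)) {
      if (parent_values.count(value) == 0) return false;
    }
  }
  return true;
}
```

With these toy parameters, `EveryChildOverlapsParent` holds for any nonzero
parent seed, no matter how the child seed is derived. How soon a forked stream
re-enters territory the parent has already covered depends on the real
engine's parameters and on the offset the fork lands at, neither of which this
sketch models.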
##########
File path: src/tir/schedule/primitive/sampling.cc
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <random>
+
+#include "../../../support/array.h"
+#include "../primitive.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace tir {
+
+int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state,
+ const Array<Integer>& candidates, const Array<FloatImm>& probs,
+ Optional<Integer>* decision) {
+ CHECK(candidates.size() == probs.size())
+ << "ValueError: number of candidates does not match number of probabilities.";
+ int i = -1;
+ int n = candidates.size();
+
+ if (decision->defined()) {
+ const auto* int_imm = decision->as<IntImmNode>();
+ i = int_imm->value;
+ CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n
+ << ", but decision is: " << i;
+ } else {
+ std::vector<double> weights = support::AsVector<FloatImm, double>(probs);
+ std::discrete_distribution<int> dist(weights.begin(), weights.end());
+ support::LinearCongruentialEngine rand_(rand_state);
+ i = dist(rand_);
+ ICHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n
+ << ", but decision is: " << i;
Review comment:
You can move this check out of the if.
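A rough sketch of what hoisting the check could look like, written against
standard-library stand-ins so it compiles on its own (C++17). The names here
are assumptions rather than TVM API: `SampleIndex` is hypothetical,
`std::optional<int>` stands in for `Optional<Integer>`, and `std::minstd_rand`
stands in for `support::LinearCongruentialEngine`. It returns the chosen index
rather than the candidate value, and writing the result back into `decision`
is assumed behavior that the quoted hunk does not show.

```cpp
#include <optional>
#include <random>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for SampleCategorical, using only the standard library.
int SampleIndex(std::minstd_rand* rng, const std::vector<double>& weights,
                std::optional<int>* decision) {
  const int n = static_cast<int>(weights.size());
  int i = -1;
  if (decision->has_value()) {
    // Replay a previously recorded decision instead of drawing a new one.
    i = decision->value();
  } else {
    // Draw an index with probability proportional to its weight.
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    i = dist(*rng);
  }
  // Single range check shared by both branches.
  if (i < 0 || i >= n) {
    throw std::out_of_range("decision index out of range");
  }
  // Assumed behavior: record the choice so the caller can replay it later.
  *decision = i;
  return i;
}
```

One thing to weigh before merging the two checks: the diff uses `CHECK` for a
caller-supplied decision and `ICHECK` for the sampled one, so a single shared
check would no longer distinguish a user error from an internal invariant
violation.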
##########
File path: src/tir/schedule/traced_schedule.h
##########
@@ -47,6 +47,15 @@ class TracedScheduleNode : public ConcreteScheduleNode {
public:
/******** Schedule: Sampling ********/
+ /*!
+ * \brief Sample an integer given the probability distribution
+ * \param candidates The candidates
+ * \param probs The probability distribution of the candidates
+ * \param decision The sampling decision
Review comment:
Can you add documentation for `decision` saying what happens when it is
set and when it is not?