Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/2455#discussion_r17769391
--- Diff: core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala ---
@@ -43,66 +46,218 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
throw new NotImplementedError("clone() is not implemented.")
}
+
+object RandomSampler {
+ // Default random number generator used by random samplers
+ def rngDefault: Random = new XORShiftRandom
+
+ // Default gap sampling maximum
+ // For sampling fractions <= this value, the gap sampling optimization will be applied.
+ // Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
+ // optimal value for this will depend on the RNG. More expensive RNGs will tend to make
+ // the optimal value higher. The most reliable way to determine this value for a given RNG
+ // is to experiment. I would expect a value of 0.5 to be close in most cases.
+ def gsmDefault: Double = 0.4
+
+ // Default gap sampling epsilon
+ // When sampling random floating point values the gap sampling logic requires value > 0. An
+ // optimal value for this parameter is at or near the minimum positive floating point value
+ // returned by nextDouble() for the RNG being used.
+ def epsDefault: Double = 5e-11
+}
+
+
/**
* :: DeveloperApi ::
* A sampler based on Bernoulli trials.
*
- * @param lb lower bound of the acceptance range
- * @param ub upper bound of the acceptance range
- * @param complement whether to use the complement of the range specified, default to false
+ * @param fraction the sampling fraction, aka Bernoulli sampling probability
* @tparam T item type
*/
@DeveloperApi
-class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
- extends RandomSampler[T, T] {
+class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
- private[random] var rng: Random = new XORShiftRandom
+ require(fraction >= 0.0 && fraction <= 1.0, "Sampling fraction must be on interval [0, 1]")
- def this(ratio: Double) = this(0.0d, ratio)
+ def this(lb: Double, ub: Double, complement: Boolean = false) =
+ this(if (complement) (1.0 - (ub - lb)) else (ub - lb))
+
+ private val rng: Random = RandomSampler.rngDefault
override def setSeed(seed: Long) = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
- items.filter { item =>
- val x = rng.nextDouble()
- (x >= lb && x < ub) ^ complement
+ fraction match {
+ case f if (f <= 0.0) => Iterator.empty
+ case f if (f >= 1.0) => items
+ case f if (f <= RandomSampler.gsmDefault) =>
+ new GapSamplingIterator(items, f, rng, RandomSampler.epsDefault)
+ case _ => items.filter(_ => (rng.nextDouble() <= fraction))
--- End diff ---
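The `GapSamplingIterator` used when `fraction <= RandomSampler.gsmDefault` is not part of this hunk. As background, gap sampling draws one random number per *accepted* item instead of one per item: under Bernoulli sampling with probability p, the number of items skipped before the next acceptance is geometrically distributed, so it can be drawn as floor(log(u) / log(1 - p)) with u uniform on (0, 1). An epsilon like `epsDefault` keeps u away from 0 so the logarithm stays finite. What follows is a minimal sketch of that technique under those assumptions; `GapSketch` and `gapSample` are hypothetical names, not the iterator in the PR.

```scala
import scala.util.Random

// Hypothetical sketch of gap sampling; not the PR's GapSamplingIterator.
object GapSketch {
  // Keeps each item with probability p (0 < p < 1), drawing one random number
  // per accepted item rather than one per item.
  def gapSample[T](items: Iterator[T], p: Double, rng: Random, epsilon: Double): Iterator[T] = {
    require(p > 0.0 && p < 1.0, "gap sampling needs 0 < p < 1")
    val lnq = math.log(1.0 - p) // log of the rejection probability, always < 0

    // Geometric(p) gap via inverse CDF: floor(log(u) / log(1 - p)).
    // u is clamped to at least epsilon so log(u) stays finite if nextDouble() returns 0.
    def nextGap(): Int = {
      val u = math.max(rng.nextDouble(), epsilon)
      (math.log(u) / lnq).toInt // non-negative, so toInt truncation equals floor
    }

    new Iterator[T] {
      // Look-ahead buffer holding the next accepted item, if there is one.
      private var pending: Option[T] = advance()

      // Skip a geometric number of items, then take the one after them.
      private def advance(): Option[T] = {
        var skip = nextGap()
        while (skip > 0 && items.hasNext) { items.next(); skip -= 1 }
        if (items.hasNext) Some(items.next()) else None
      }

      override def hasNext: Boolean = pending.isDefined

      override def next(): T = pending match {
        case Some(item) =>
          pending = advance()
          item
        case None => throw new NoSuchElementException("next on empty iterator")
      }
    }
  }
}
```

For example, `GapSketch.gapSample((1 to 100000).iterator, 0.1, new Random(42), 5e-11).size` should land close to 10000. The saving over a per-item `filter` grows as p shrinks, which fits the diff only switching to gap sampling below a threshold.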
Did you test whether `rdd.randomSplit()` will produce non-overlapping
subsets with this change?
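For context, the concern is that `randomSplit` (assumed here to build one range-based sampler per split from adjacent [lb, ub) intervals of the cumulative weights, all seeded identically per partition) relies on each item's single uniform draw landing in exactly one interval. The new secondary constructor keeps only `ub - lb`, so, for example, `BernoulliSampler(0.0, 0.3)` and `BernoulliSampler(0.7, 1.0)` both become a 0.3-fraction sampler and, given the same seed, would return identical subsets. The sketch below uses hypothetical plain-Scala stand-ins (`SplitOverlapSketch`, `rangeSample`, `fractionSample`; no Spark classes) and models only the `filter` branch; the gap-sampling branch consumes random numbers differently, so its exact overlap would differ.

```scala
import scala.util.Random

// Hypothetical stand-ins for the two sampler designs; no Spark classes are used.
object SplitOverlapSketch {
  // Range-based acceptance, like the removed BernoulliSampler(lb, ub):
  // keep an item when its uniform draw falls in [lb, ub).
  def rangeSample(items: List[Int], lb: Double, ub: Double, seed: Long): List[Int] = {
    val rng = new Random(seed)
    items.filter { _ => val x = rng.nextDouble(); x >= lb && x < ub }
  }

  // Fraction-only acceptance, like the filter branch of the new BernoulliSampler:
  // the position of [lb, ub) is gone, only its width survives.
  def fractionSample(items: List[Int], fraction: Double, seed: Long): List[Int] = {
    val rng = new Random(seed)
    items.filter(_ => rng.nextDouble() <= fraction)
  }
}
```

With weights 0.3 / 0.7 and a shared seed, `rangeSample` over [0.0, 0.3) and [0.3, 1.0) yields disjoint subsets whose union is the input, while `fractionSample` with 0.3 and 0.7 yields a 0.3 subset that is contained in the 0.7 subset, because both walk the same sequence of draws.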