Repository: spark
Updated Branches:
  refs/heads/master c8388297c -> 68c0c460b


[SPARK-13742] [CORE] Add non-iterator interface to RandomSampler

JIRA: https://issues.apache.org/jira/browse/SPARK-13742

## What changes were proposed in this pull request?

`RandomSampler.sample` currently accepts an iterator as input and outputs 
another iterator. This makes it unsuitable for use in whole-stage codegen of 
the `Sampler` operator (#11517). This change adds a non-iterator interface to 
`RandomSampler`.

This change adds a new method `def sample(): Int` to the trait `RandomSampler`. 
Because we don't need to know the actual values of the items being sampled, 
this new method takes no arguments.

This method will decide whether to sample the next item or not. It returns how 
many times the next item will be sampled.

For `BernoulliSampler` and `BernoulliCellSampler`, the returned count can only 
be 0 or 1; it simply indicates whether the next item is sampled or not.

For `PoissonSampler`, the returned value can be more than 1, meaning the next 
item will be sampled multiple times.

## How was this patch tested?

Tests are added to `RandomSamplerSuite`.

Author: Liang-Chi Hsieh <[email protected]>
Author: Liang-Chi Hsieh <[email protected]>
Author: Liang-Chi Hsieh <[email protected]>

Closes #11578 from viirya/random-sampler-no-iterator.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/68c0c460
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/68c0c460
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/68c0c460

Branch: refs/heads/master
Commit: 68c0c460bfc51d7f69d09b613c49c212dd0b375c
Parents: c838829
Author: Liang-Chi Hsieh <[email protected]>
Authored: Mon Mar 28 09:58:47 2016 -0700
Committer: Davies Liu <[email protected]>
Committed: Mon Mar 28 09:58:47 2016 -0700

----------------------------------------------------------------------
 .../spark/util/random/RandomSampler.scala       | 201 +++++++++----------
 .../rdd/PartitionwiseSampledRDDSuite.scala      |   2 +
 .../spark/util/random/RandomSamplerSuite.scala  | 197 ++++++++++++++++++
 3 files changed, 289 insertions(+), 111 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/68c0c460/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
----------------------------------------------------------------------
diff --git 
a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala 
b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 3c61528..2921b93 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi
 trait RandomSampler[T, U] extends Pseudorandom with Cloneable with 
Serializable {
 
   /** take a random sample */
-  def sample(items: Iterator[T]): Iterator[U]
+  def sample(items: Iterator[T]): Iterator[U] =
+    items.filter(_ => sample > 0).asInstanceOf[Iterator[U]]
+
+  /**
+   * Whether to sample the next item or not.
+   * Return how many times the next item will be sampled. Return 0 if it is 
not sampled.
+   */
+  def sample(): Int
 
   /** return a copy of the RandomSampler object */
   override def clone: RandomSampler[T, U] =
@@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, 
complement: Boolean = fals
 
   override def setSeed(seed: Long): Unit = rng.setSeed(seed)
 
-  override def sample(items: Iterator[T]): Iterator[T] = {
+  override def sample(): Int = {
     if (ub - lb <= 0.0) {
-      if (complement) items else Iterator.empty
+      if (complement) 1 else 0
     } else {
-      if (complement) {
-        items.filter { item => {
-          val x = rng.nextDouble()
-          (x < lb) || (x >= ub)
-        }}
-      } else {
-        items.filter { item => {
-          val x = rng.nextDouble()
-          (x >= lb) && (x < ub)
-        }}
-      }
+      val x = rng.nextDouble()
+      val n = if ((x >= lb) && (x < ub)) 1 else 0
+      if (complement) 1 - n else n
     }
   }
 
@@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) 
extends RandomSampler[T, T
 
   override def setSeed(seed: Long): Unit = rng.setSeed(seed)
 
-  override def sample(items: Iterator[T]): Iterator[T] = {
+  private lazy val gapSampling: GapSampling =
+    new GapSampling(fraction, rng, RandomSampler.rngEpsilon)
+
+  override def sample(): Int = {
     if (fraction <= 0.0) {
-      Iterator.empty
+      0
     } else if (fraction >= 1.0) {
-      items
+      1
     } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
-      new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
+      gapSampling.sample()
     } else {
-      items.filter { _ => rng.nextDouble() <= fraction }
+      if (rng.nextDouble() <= fraction) {
+        1
+      } else {
+        0
+      }
     }
   }
 
@@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag](
     rngGap.setSeed(seed)
   }
 
-  override def sample(items: Iterator[T]): Iterator[T] = {
+  private lazy val gapSamplingReplacement =
+    new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)
+
+  override def sample(): Int = {
     if (fraction <= 0.0) {
-      Iterator.empty
+      0
     } else if (useGapSamplingIfPossible &&
                fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
-      new GapSamplingReplacementIterator(items, fraction, rngGap, 
RandomSampler.rngEpsilon)
+      gapSamplingReplacement.sample()
+    } else {
+      rng.sample()
+    }
+  }
+
+  override def sample(items: Iterator[T]): Iterator[T] = {
+    if (fraction <= 0.0) {
+      Iterator.empty
     } else {
+      val useGapSampling = useGapSamplingIfPossible &&
+        fraction <= RandomSampler.defaultMaxGapSamplingFraction
+
       items.flatMap { item =>
-        val count = rng.sample()
+        val count = if (useGapSampling) gapSamplingReplacement.sample() else 
rng.sample()
         if (count == 0) Iterator.empty else Iterator.fill(count)(item)
       }
     }
@@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag](
 
 
 private[spark]
-class GapSamplingIterator[T: ClassTag](
-    var data: Iterator[T],
+class GapSampling(
     f: Double,
     rng: Random = RandomSampler.newDefaultRNG,
-    epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+    epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
 
   require(f > 0.0  &&  f < 1.0, s"Sampling fraction ($f) must reside on open 
interval (0, 1)")
   require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
 
-  /** implement efficient linear-sequence drop until Scala includes fix for 
jira SI-8835. */
-  private val iterDrop: Int => Unit = {
-    val arrayClass = Array.empty[T].iterator.getClass
-    val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
-    data.getClass match {
-      case `arrayClass` =>
-        (n: Int) => { data = data.drop(n) }
-      case `arrayBufferClass` =>
-        (n: Int) => { data = data.drop(n) }
-      case _ =>
-        (n: Int) => {
-          var j = 0
-          while (j < n && data.hasNext) {
-            data.next()
-            j += 1
-          }
-        }
-    }
-  }
-
-  override def hasNext: Boolean = data.hasNext
+  private val lnq = math.log1p(-f)
 
-  override def next(): T = {
-    val r = data.next()
-    advance()
-    r
+  /** Return 1 if the next item should be sampled. Otherwise, return 0. */
+  def sample(): Int = {
+    if (countForDropping > 0) {
+      countForDropping -= 1
+      0
+    } else {
+      advance()
+      1
+    }
   }
 
-  private val lnq = math.log1p(-f)
+  private var countForDropping: Int = 0
 
-  /** skip elements that won't be sampled, according to geometric dist P(k) = 
(f)(1-f)^k. */
+  /**
+   * Decide the number of elements that won't be sampled,
+   * according to geometric dist P(k) = (f)(1-f)^k.
+   */
   private def advance(): Unit = {
     val u = math.max(rng.nextDouble(), epsilon)
-    val k = (math.log(u) / lnq).toInt
-    iterDrop(k)
+    countForDropping = (math.log(u) / lnq).toInt
   }
 
   /** advance to first sample as part of object construction. */
@@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag](
   // work reliably.
 }
 
+
 private[spark]
-class GapSamplingReplacementIterator[T: ClassTag](
-    var data: Iterator[T],
-    f: Double,
-    rng: Random = RandomSampler.newDefaultRNG,
-    epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+class GapSamplingReplacement(
+    val f: Double,
+    val rng: Random = RandomSampler.newDefaultRNG,
+    epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {
 
   require(f > 0.0, s"Sampling fraction ($f) must be > 0")
   require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
 
-  /** implement efficient linear-sequence drop until scala includes fix for 
jira SI-8835. */
-  private val iterDrop: Int => Unit = {
-    val arrayClass = Array.empty[T].iterator.getClass
-    val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
-    data.getClass match {
-      case `arrayClass` =>
-        (n: Int) => { data = data.drop(n) }
-      case `arrayBufferClass` =>
-        (n: Int) => { data = data.drop(n) }
-      case _ =>
-        (n: Int) => {
-          var j = 0
-          while (j < n && data.hasNext) {
-            data.next()
-            j += 1
-          }
-        }
-    }
-  }
-
-  /** current sampling value, and its replication factor, as we are sampling 
with replacement. */
-  private var v: T = _
-  private var rep: Int = 0
-
-  override def hasNext: Boolean = data.hasNext || rep > 0
-
-  override def next(): T = {
-    val r = v
-    rep -= 1
-    if (rep <= 0) advance()
-    r
-  }
-
-  /**
-   * Skip elements with replication factor zero (i.e. elements that won't be 
sampled).
-   * Samples 'k' from geometric distribution  P(k) = (1-q)(q)^k, where q = 
e^(-f), that is
-   * q is the probability of Poisson(0; f)
-   */
-  private def advance(): Unit = {
-    val u = math.max(rng.nextDouble(), epsilon)
-    val k = (math.log(u) / (-f)).toInt
-    iterDrop(k)
-    // set the value and replication factor for the next value
-    if (data.hasNext) {
-      v = data.next()
-      rep = poissonGE1
-    }
-  }
-
-  private val q = math.exp(-f)
+  protected val q = math.exp(-f)
 
   /**
    * Sample from Poisson distribution, conditioned such that the sampled value 
is >= 1.
    * This is an adaptation from the algorithm for Generating Poisson 
distributed random variables:
    * http://en.wikipedia.org/wiki/Poisson_distribution
    */
-  private def poissonGE1: Int = {
+  protected def poissonGE1: Int = {
     // simulate that the standard poisson sampling
     // gave us at least one iteration, for a sample of >= 1
     var pp = q + ((1.0 - q) * rng.nextDouble())
@@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag](
     }
     r
   }
+  private var countForDropping: Int = 0
+
+  def sample(): Int = {
+    if (countForDropping > 0) {
+      countForDropping -= 1
+      0
+    } else {
+      val r = poissonGE1
+      advance()
+      r
+    }
+  }
+
+  /**
+   * Skip elements with replication factor zero (i.e. elements that won't be 
sampled).
+   * Samples 'k' from geometric distribution  P(k) = (1-q)(q)^k, where q = 
e^(-f), that is
+   * q is the probability of Poisson(0; f)
+   */
+  private def advance(): Unit = {
+    val u = math.max(rng.nextDouble(), epsilon)
+    countForDropping = (math.log(u) / (-f)).toInt
+  }
 
   /** advance to first sample as part of object construction. */
   advance()

http://git-wip-us.apache.org/repos/asf/spark/blob/68c0c460/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 132a5fa..cb0de1c 100644
--- 
a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] {
     s = seed
   }
 
+  override def sample(): Int = 1
+
   override def sample(items: Iterator[Long]): Iterator[Long] = {
     Iterator(s)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/68c0c460/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala 
b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index 791491d..7eb2f56 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -129,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     t(m / 2)
   }
 
+  def replacementSampling(data: Iterator[Int], sampler: PoissonSampler[Int]): 
Iterator[Int] = {
+    data.flatMap { item =>
+      val count = sampler.sample()
+      if (count == 0) Iterator.empty else Iterator.fill(count)(item)
+    }
+  }
+
   test("utilities") {
     val s1 = Array(0, 1, 1, 0, 2)
     val s2 = Array(1, 0, 3, 2, 1)
@@ -189,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     d should be > D
   }
 
+  test("bernoulli sampling without iterator") {
+    // Tests expect maximum gap sampling fraction to be this value
+    RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+    var d: Double = 0.0
+
+    val data = Iterator.from(0)
+
+    var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.5)))
+    d should be < D
+
+    sampler = new BernoulliSampler[Int](0.7)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.7)))
+    d should be < D
+
+    sampler = new BernoulliSampler[Int](0.9)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.9)))
+    d should be < D
+
+    // sampling at different frequencies should show up as statistically 
different:
+    sampler = new BernoulliSampler[Int](0.5)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.6)))
+    d should be > D
+  }
+
   test("bernoulli sampling with gap sampling optimization") {
     // Tests expect maximum gap sampling fraction to be this value
     RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -217,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     d should be > D
   }
 
+  test("bernoulli sampling (without iterator) with gap sampling optimization") 
{
+    // Tests expect maximum gap sampling fraction to be this value
+    RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+    var d: Double = 0.0
+
+    val data = Iterator.from(0)
+
+    var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)),
+      gaps(sample(Iterator.from(0), 0.01)))
+    d should be < D
+
+    sampler = new BernoulliSampler[Int](0.1)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.1)))
+    d should be < D
+
+    sampler = new BernoulliSampler[Int](0.3)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.3)))
+    d should be < D
+
+    // sampling at different frequencies should show up as statistically 
different:
+    sampler = new BernoulliSampler[Int](0.3)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.4)))
+    d should be > D
+  }
+
   test("bernoulli boundary cases") {
     val data = (1 to 100).toArray
 
@@ -233,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     sampler.sample(data.iterator).toArray should be (data)
   }
 
+  test("bernoulli (without iterator) boundary cases") {
+    val data = (1 to 100).toArray
+
+    var sampler = new BernoulliSampler[Int](0.0)
+    data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int])
+
+    sampler = new BernoulliSampler[Int](1.0)
+    data.filter(_ => sampler.sample() > 0) should be (data)
+
+    sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 
2.0))
+    data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int])
+
+    sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 
2.0))
+    data.filter(_ => sampler.sample() > 0) should be (data)
+  }
+
   test("bernoulli data types") {
     // Tests expect maximum gap sampling fraction to be this value
     RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -341,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     d should be > D
   }
 
+  test("replacement sampling without iterator") {
+    // Tests expect maximum gap sampling fraction to be this value
+    RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+    var d: Double = 0.0
+
+    val data = Iterator.from(0)
+
+    var sampler = new PoissonSampler[Int](0.5)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.5)))
+    d should be < D
+
+    sampler = new PoissonSampler[Int](0.7)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.7)))
+    d should be < D
+
+    sampler = new PoissonSampler[Int](0.9)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.9)))
+    d should be < D
+
+    // sampling at different frequencies should show up as statistically 
different:
+    sampler = new PoissonSampler[Int](0.5)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.6)))
+    d should be > D
+  }
+
   test("replacement sampling with gap sampling") {
     // Tests expect maximum gap sampling fraction to be this value
     RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -369,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     d should be > D
   }
 
+  test("replacement sampling (without iterator) with gap sampling") {
+    // Tests expect maximum gap sampling fraction to be this value
+    RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+    var d: Double = 0.0
+
+    val data = Iterator.from(0)
+
+    var sampler = new PoissonSampler[Int](0.01)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.01)))
+    d should be < D
+
+    sampler = new PoissonSampler[Int](0.1)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.1)))
+    d should be < D
+
+    sampler = new PoissonSampler[Int](0.3)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.3)))
+    d should be < D
+
+    // sampling at different frequencies should show up as statistically 
different:
+    sampler = new PoissonSampler[Int](0.3)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(replacementSampling(data, sampler)), 
gaps(sampleWR(Iterator.from(0), 0.4)))
+    d should be > D
+  }
+
   test("replacement boundary cases") {
     val data = (1 to 100).toArray
 
@@ -383,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     sampler.sample(data.iterator).length should be > (data.length)
   }
 
+  test("replacement (without) boundary cases") {
+    val data = (1 to 100).toArray
+
+    var sampler = new PoissonSampler[Int](0.0)
+    replacementSampling(data.iterator, sampler).toArray should be 
(Array.empty[Int])
+
+    sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 
2.0))
+    replacementSampling(data.iterator, sampler).toArray should be 
(Array.empty[Int])
+
+    // sampling with replacement has no upper bound on sampling fraction
+    sampler = new PoissonSampler[Int](2.0)
+    replacementSampling(data.iterator, sampler).length should be > 
(data.length)
+  }
+
   test("replacement data types") {
     // Tests expect maximum gap sampling fraction to be this value
     RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
@@ -477,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     d should be < D
   }
 
+  test("bernoulli partitioning sampling without iterator") {
+    var d: Double = 0.0
+
+    val data = Iterator.from(0)
+
+    var sampler = new BernoulliCellSampler[Int](0.1, 0.2)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.1)))
+    d should be < D
+
+    sampler = new BernoulliCellSampler[Int](0.1, 0.2, true)
+    sampler.setSeed(rngSeed.nextLong)
+    d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), 
gaps(sample(Iterator.from(0), 0.9)))
+    d should be < D
+  }
+
   test("bernoulli partitioning boundary cases") {
     val data = (1 to 100).toArray
     val d = RandomSampler.roundingEpsilon / 2.0
@@ -500,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with 
Matchers {
     sampler.sample(data.iterator).toArray should be (Array.empty[Int])
   }
 
+  test("bernoulli partitioning (without iterator) boundary cases") {
+    val data = (1 to 100).toArray
+    val d = RandomSampler.roundingEpsilon / 2.0
+
+    var sampler = new BernoulliCellSampler[Int](0.0, 0.0)
+    data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+    sampler = new BernoulliCellSampler[Int](0.5, 0.5)
+    data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+    sampler = new BernoulliCellSampler[Int](1.0, 1.0)
+    data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+
+    sampler = new BernoulliCellSampler[Int](0.0, 1.0)
+    data.filter(_ => sampler.sample() > 0).toArray should be (data)
+
+    sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d)
+    data.filter(_ => sampler.sample() > 0).toArray should be (data)
+
+    sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d)
+    data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int])
+  }
+
   test("bernoulli partitioning data") {
     val seed = rngSeed.nextLong
     val data = (1 to 100).toArray


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to