Repository: spark
Updated Branches:
  refs/heads/master 28dcbb531 -> b715aa0c8


[SPARK-2937] Separate out sampleByKeyExact as its own API in PairRDDFunctions

This change enables consistency with the Python API and allows the 
`sampleByKeyExact` API to carry the `Experimental` label.

Author: Doris Xin <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes #1866 from dorx/stratified and squashes the following commits:

0ad97b2 [Doris Xin] reviewer comments.
2948aae [Doris Xin] remove unrelated changes
e990325 [Doris Xin] Merge branch 'master' into stratified
555a3f9 [Doris Xin] separate out sampleByKeyExact as its own API
616e55c [Doris Xin] merge master
245439e [Doris Xin] moved minSamplingRate to getUpperBound
eaf5771 [Doris Xin] bug fixes.
17a381b [Doris Xin] fixed a merge issue and a failed unit
ea7d27f [Doris Xin] merge master
b223529 [Xiangrui Meng] use approx bounds for poisson fix poisson mean for 
waitlisting add unit tests for Java
b3013a4 [Xiangrui Meng] move math3 back to test scope
eecee5f [Doris Xin] Merge branch 'master' into stratified
f4c21f3 [Doris Xin] Reviewer comments
a10e68d [Doris Xin] style fix
a2bf756 [Doris Xin] Merge branch 'master' into stratified
680b677 [Doris Xin] use mapPartitionWithIndex instead
9884a9f [Doris Xin] style fix
bbfb8c9 [Doris Xin] Merge branch 'master' into stratified
ee9d260 [Doris Xin] addressed reviewer comments
6b5b10b [Doris Xin] Merge branch 'master' into stratified
254e03c [Doris Xin] minor fixes and Java API.
4ad516b [Doris Xin] remove unused imports from PairRDDFunctions
bd9dc6e [Doris Xin] unit bug and style violation fixed
1fe1cff [Doris Xin] Changed fractionByKey to a map to enable arg check
944a10c [Doris Xin] [SPARK-2145] Add lower bound on sampling rate
0214a76 [Doris Xin] cleanUp
90d94c0 [Doris Xin] merge master
9e74ab5 [Doris Xin] Separated out most of the logic in sampleByKey
7327611 [Doris Xin] merge master
50581fc [Doris Xin] added a TODO for logging in python
46f6c8c [Doris Xin] fixed the NPE caused by closures being cleaned before being 
passed into the aggregate function
7e1a481 [Doris Xin] changed the permission on SamplingUtil
1d413ce [Doris Xin] fixed checkstyle issues
9ee94ee [Doris Xin] [SPARK-2082] stratified sampling in PairRDDFunctions that 
guarantees exact sample size
e3fd6a6 [Doris Xin] Merge branch 'master' into takeSample
7cab53a [Doris Xin] fixed import bug in rdd.py
ffea61a [Doris Xin] SPARK-1939: Refactor takeSample method in RDD
1441977 [Doris Xin] SPARK-1939 Refactor takeSample method in RDD to use ScaSRS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b715aa0c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b715aa0c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b715aa0c

Branch: refs/heads/master
Commit: b715aa0c8090cd57158ead2a1b35632cb98a6277
Parents: 28dcbb5
Author: Doris Xin <[email protected]>
Authored: Sun Aug 10 16:31:07 2014 -0700
Committer: Xiangrui Meng <[email protected]>
Committed: Sun Aug 10 16:31:07 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaPairRDD.scala |  68 +++---
 .../org/apache/spark/rdd/PairRDDFunctions.scala |  51 +++--
 .../java/org/apache/spark/JavaAPISuite.java     |  20 +-
 .../spark/rdd/PairRDDFunctionsSuite.scala       | 205 ++++++++++++-------
 4 files changed, 216 insertions(+), 128 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b715aa0c/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala 
b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 76d4193..feeb6c0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -133,68 +133,62 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * Return a subset of this RDD sampled by key (via stratified sampling).
    *
    * Create a sample of this RDD using variable sampling rates for different 
keys as specified by
-   * `fractions`, a key to sampling rate map.
-   *
-   * If `exact` is set to false, create the sample via simple random sampling, 
with one pass
-   * over the RDD, to produce a sample of size that's approximately equal to 
the sum of
-   * math.ceil(numItems * samplingRate) over all key values; otherwise, use 
additional passes over
-   * the RDD to create a sample size that's exactly equal to the sum of
+   * `fractions`, a key to sampling rate map, via simple random sampling with 
one pass over the
+   * RDD, to produce a sample of size that's approximately equal to the sum of
    * math.ceil(numItems * samplingRate) over all key values.
    */
   def sampleByKey(withReplacement: Boolean,
       fractions: JMap[K, Double],
-      exact: Boolean,
       seed: Long): JavaPairRDD[K, V] =
-    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, 
seed))
+    new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, seed))
 
   /**
    * Return a subset of this RDD sampled by key (via stratified sampling).
    *
    * Create a sample of this RDD using variable sampling rates for different 
keys as specified by
-   * `fractions`, a key to sampling rate map.
-   *
-   * If `exact` is set to false, create the sample via simple random sampling, 
with one pass
-   * over the RDD, to produce a sample of size that's approximately equal to 
the sum of
-   * math.ceil(numItems * samplingRate) over all key values; otherwise, use 
additional passes over
-   * the RDD to create a sample size that's exactly equal to the sum of
+   * `fractions`, a key to sampling rate map, via simple random sampling with 
one pass over the
+   * RDD, to produce a sample of size that's approximately equal to the sum of
    * math.ceil(numItems * samplingRate) over all key values.
    *
-   * Use Utils.random.nextLong as the default seed for the random number 
generator
+   * Use Utils.random.nextLong as the default seed for the random number 
generator.
    */
   def sampleByKey(withReplacement: Boolean,
-      fractions: JMap[K, Double],
-      exact: Boolean): JavaPairRDD[K, V] =
-    sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)
+      fractions: JMap[K, Double]): JavaPairRDD[K, V] =
+    sampleByKey(withReplacement, fractions, Utils.random.nextLong)
 
   /**
-   * Return a subset of this RDD sampled by key (via stratified sampling).
-   *
-   * Create a sample of this RDD using variable sampling rates for different 
keys as specified by
-   * `fractions`, a key to sampling rate map.
+   * ::Experimental::
+   * Return a subset of this RDD sampled by key (via stratified sampling) 
containing exactly
+   * math.ceil(numItems * samplingRate) for each stratum (group of pairs with 
the same key).
    *
-   * Produce a sample of size that's approximately equal to the sum of
-   * math.ceil(numItems * samplingRate) over all key values with one pass over 
the RDD via
-   * simple random sampling.
+   * This method differs from [[sampleByKey]] in that we make additional 
passes over the RDD to
+   * create a sample size that's exactly equal to the sum of 
math.ceil(numItems * samplingRate)
+   * over all key values with a 99.99% confidence. When sampling without 
replacement, we need one
+   * additional pass over the RDD to guarantee sample size; when sampling with 
replacement, we need
+   * two additional passes.
    */
-  def sampleByKey(withReplacement: Boolean,
+  @Experimental
+  def sampleByKeyExact(withReplacement: Boolean,
       fractions: JMap[K, Double],
       seed: Long): JavaPairRDD[K, V] =
-    sampleByKey(withReplacement, fractions, false, seed)
+    new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions, 
seed))
 
   /**
-   * Return a subset of this RDD sampled by key (via stratified sampling).
+   * ::Experimental::
+   * Return a subset of this RDD sampled by key (via stratified sampling) 
containing exactly
+   * math.ceil(numItems * samplingRate) for each stratum (group of pairs with 
the same key).
    *
-   * Create a sample of this RDD using variable sampling rates for different 
keys as specified by
-   * `fractions`, a key to sampling rate map.
-   *
-   * Produce a sample of size that's approximately equal to the sum of
-   * math.ceil(numItems * samplingRate) over all key values with one pass over 
the RDD via
-   * simple random sampling.
+   * This method differs from [[sampleByKey]] in that we make additional 
passes over the RDD to
+   * create a sample size that's exactly equal to the sum of 
math.ceil(numItems * samplingRate)
+   * over all key values with a 99.99% confidence. When sampling without 
replacement, we need one
+   * additional pass over the RDD to guarantee sample size; when sampling with 
replacement, we need
+   * two additional passes.
    *
-   * Use Utils.random.nextLong as the default seed for the random number 
generator
+   * Use Utils.random.nextLong as the default seed for the random number 
generator.
    */
-  def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): 
JavaPairRDD[K, V] =
-    sampleByKey(withReplacement, fractions, false, Utils.random.nextLong)
+  @Experimental
+  def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): 
JavaPairRDD[K, V] =
+    sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong)
 
   /**
    * Return the union of this RDD and another one. Any identical elements will 
appear multiple

http://git-wip-us.apache.org/repos/asf/spark/blob/b715aa0c/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala 
b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 5dd6472..f6d9d12 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -197,33 +197,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
    * Return a subset of this RDD sampled by key (via stratified sampling).
    *
    * Create a sample of this RDD using variable sampling rates for different 
keys as specified by
-   * `fractions`, a key to sampling rate map.
-   *
-   * If `exact` is set to false, create the sample via simple random sampling, 
with one pass
-   * over the RDD, to produce a sample of size that's approximately equal to 
the sum of
-   * math.ceil(numItems * samplingRate) over all key values; otherwise, use
-   * additional passes over the RDD to create a sample size that's exactly 
equal to the sum of
-   * math.ceil(numItems * samplingRate) over all key values with a 99.99% 
confidence. When sampling
-   * without replacement, we need one additional pass over the RDD to 
guarantee sample size;
-   * when sampling with replacement, we need two additional passes.
+   * `fractions`, a key to sampling rate map, via simple random sampling with 
one pass over the
+   * RDD, to produce a sample of size that's approximately equal to the sum of
+   * math.ceil(numItems * samplingRate) over all key values.
    *
    * @param withReplacement whether to sample with or without replacement
    * @param fractions map of specific keys to sampling rates
    * @param seed seed for the random number generator
-   * @param exact whether sample size needs to be exactly math.ceil(fraction * 
size) per key
    * @return RDD containing the sampled subset
    */
   def sampleByKey(withReplacement: Boolean,
       fractions: Map[K, Double],
-      exact: Boolean = false,
-      seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+    require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+    val samplingFunc = if (withReplacement) {
+      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, 
false, seed)
+    } else {
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, 
false, seed)
+    }
+    self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+  }
+
+  /**
+   * ::Experimental::
+   * Return a subset of this RDD sampled by key (via stratified sampling) 
containing exactly
+   * math.ceil(numItems * samplingRate) for each stratum (group of pairs with 
the same key).
+   *
+   * This method differs from [[sampleByKey]] in that we make additional 
passes over the RDD to
+   * create a sample size that's exactly equal to the sum of 
math.ceil(numItems * samplingRate)
+   * over all key values with a 99.99% confidence. When sampling without 
replacement, we need one
+   * additional pass over the RDD to guarantee sample size; when sampling with 
replacement, we need
+   * two additional passes.
+   *
+   * @param withReplacement whether to sample with or without replacement
+   * @param fractions map of specific keys to sampling rates
+   * @param seed seed for the random number generator
+   * @return RDD containing the sampled subset
+   */
+  @Experimental
+  def sampleByKeyExact(withReplacement: Boolean,
+      fractions: Map[K, Double],
+      seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
 
     require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
 
     val samplingFunc = if (withReplacement) {
-      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, 
exact, seed)
+      StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, 
true, seed)
     } else {
-      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, 
exact, seed)
+      StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, 
true, seed)
     }
     self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b715aa0c/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java 
b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 56150ca..e1c13de 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1239,12 +1239,28 @@ public class JavaAPISuite implements Serializable {
     Assert.assertTrue(worCounts.size() == 2);
     Assert.assertTrue(worCounts.get(0) > 0);
     Assert.assertTrue(worCounts.get(1) > 0);
-    JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKey(true, fractions, 
true, 1L);
+  }
+
+  @Test
+  @SuppressWarnings("unchecked")
+  public void sampleByKeyExact() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 
8), 3);
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.mapToPair(
+      new PairFunction<Integer, Integer, Integer>() {
+          @Override
+          public Tuple2<Integer, Integer> call(Integer i) {
+              return new Tuple2<Integer, Integer>(i % 2, 1);
+          }
+      });
+    Map<Integer, Object> fractions = Maps.newHashMap();
+    fractions.put(0, 0.5);
+    fractions.put(1, 1.0);
+    JavaPairRDD<Integer, Integer> wrExact = rdd2.sampleByKeyExact(true, 
fractions, 1L);
     Map<Integer, Long> wrExactCounts = (Map<Integer, Long>) (Object) 
wrExact.countByKey();
     Assert.assertTrue(wrExactCounts.size() == 2);
     Assert.assertTrue(wrExactCounts.get(0) == 2);
     Assert.assertTrue(wrExactCounts.get(1) == 4);
-    JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKey(false, 
fractions, true, 1L);
+    JavaPairRDD<Integer, Integer> worExact = rdd2.sampleByKeyExact(false, 
fractions, 1L);
     Map<Integer, Long> worExactCounts = (Map<Integer, Long>) (Object) 
worExact.countByKey();
     Assert.assertTrue(worExactCounts.size() == 2);
     Assert.assertTrue(worExactCounts.get(0) == 2);

http://git-wip-us.apache.org/repos/asf/spark/blob/b715aa0c/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala 
b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 4f49d4a..63d3ddb 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -84,118 +84,81 @@ class PairRDDFunctionsSuite extends FunSuite with 
SharedSparkContext {
   }
 
   test("sampleByKey") {
-    def stratifier (fractionPositive: Double) = {
-      (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
-    }
 
-    def checkSize(exact: Boolean,
-        withReplacement: Boolean,
-        expected: Long,
-        actual: Long,
-        p: Double): Boolean = {
-      if (exact) {
-        return expected == actual
-      }
-      val stdev = if (withReplacement) math.sqrt(expected) else 
math.sqrt(expected * p * (1 - p))
-      // Very forgiving margin since we're dealing with very small sample 
sizes most of the time
-      math.abs(actual - expected) <= 6 * stdev
+    val defaultSeed = 1L
+
+    // vary RDD size
+    for (n <- List(100, 1000, 1000000)) {
+      val data = sc.parallelize(1 to n, 2)
+      val fractionPositive = 0.3
+      val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+      val samplingRate = 0.1
+      StratifiedAuxiliary.testSample(stratifiedData, samplingRate, 
defaultSeed, n)
     }
 
-    // Without replacement validation
-    def takeSampleAndValidateBernoulli(stratifiedData: RDD[(String, Int)],
-        exact: Boolean,
-        samplingRate: Double,
-        seed: Long,
-        n: Long) = {
-      val expectedSampleSize = stratifiedData.countByKey()
-        .mapValues(count => math.ceil(count * samplingRate).toInt)
-      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
-      val sampleCounts = sample.countByKey()
-      val takeSample = sample.collect()
-      sampleCounts.foreach { case(k, v) =>
-        assert(checkSize(exact, false, expectedSampleSize(k), v, 
samplingRate)) }
-      assert(takeSample.size === takeSample.toSet.size)
-      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not 
in [1, $n]") }
+    // vary fractionPositive
+    for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
+      val n = 100
+      val data = sc.parallelize(1 to n, 2)
+      val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+      val samplingRate = 0.1
+      StratifiedAuxiliary.testSample(stratifiedData, samplingRate, 
defaultSeed, n)
     }
 
-    // With replacement validation
-    def takeSampleAndValidatePoisson(stratifiedData: RDD[(String, Int)],
-        exact: Boolean,
-        samplingRate: Double,
-        seed: Long,
-        n: Long) = {
-      val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
-        math.ceil(count * samplingRate).toInt)
-      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
-      val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
-      val sampleCounts = sample.countByKey()
-      val takeSample = sample.collect()
-      sampleCounts.foreach { case(k, v) =>
-        assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) 
}
-      val groupedByKey = takeSample.groupBy(_._1)
-      for ((key, v) <- groupedByKey) {
-        if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
-          // sample large enough for there to be repeats with high likelihood
-          assert(v.toSet.size < expectedSampleSize(key))
-        } else {
-          if (exact) {
-            assert(v.toSet.size <= expectedSampleSize(key))
-          } else {
-            assert(checkSize(false, true, expectedSampleSize(key), 
v.toSet.size, samplingRate))
-          }
-        }
-      }
-      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not 
in [1, $n]") }
+    // Use the same data for the rest of the tests
+    val fractionPositive = 0.3
+    val n = 100
+    val data = sc.parallelize(1 to n, 2)
+    val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
+
+    // vary seed
+    for (seed <- defaultSeed to defaultSeed + 5L) {
+      val samplingRate = 0.1
+      StratifiedAuxiliary.testSample(stratifiedData, samplingRate, seed, n)
     }
 
-    def checkAllCombos(stratifiedData: RDD[(String, Int)],
-        samplingRate: Double,
-        seed: Long,
-        n: Long) = {
-      takeSampleAndValidateBernoulli(stratifiedData, true, samplingRate, seed, 
n)
-      takeSampleAndValidateBernoulli(stratifiedData, false, samplingRate, 
seed, n)
-      takeSampleAndValidatePoisson(stratifiedData, true, samplingRate, seed, n)
-      takeSampleAndValidatePoisson(stratifiedData, false, samplingRate, seed, 
n)
+    // vary sampling rate
+    for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
+      StratifiedAuxiliary.testSample(stratifiedData, samplingRate, 
defaultSeed, n)
     }
+  }
 
+  test("sampleByKeyExact") {
     val defaultSeed = 1L
 
     // vary RDD size
     for (n <- List(100, 1000, 1000000)) {
       val data = sc.parallelize(1 to n, 2)
       val fractionPositive = 0.3
-      val stratifiedData = data.keyBy(stratifier(fractionPositive))
-
+      val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
       val samplingRate = 0.1
-      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+      StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, 
defaultSeed, n)
     }
 
     // vary fractionPositive
     for (fractionPositive <- List(0.1, 0.3, 0.5, 0.7, 0.9)) {
       val n = 100
       val data = sc.parallelize(1 to n, 2)
-      val stratifiedData = data.keyBy(stratifier(fractionPositive))
-
+      val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
       val samplingRate = 0.1
-      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+      StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, 
defaultSeed, n)
     }
 
     // Use the same data for the rest of the tests
     val fractionPositive = 0.3
     val n = 100
     val data = sc.parallelize(1 to n, 2)
-    val stratifiedData = data.keyBy(stratifier(fractionPositive))
+    val stratifiedData = 
data.keyBy(StratifiedAuxiliary.stratifier(fractionPositive))
 
     // vary seed
     for (seed <- defaultSeed to defaultSeed + 5L) {
       val samplingRate = 0.1
-      checkAllCombos(stratifiedData, samplingRate, seed, n)
+      StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, seed, 
n)
     }
 
     // vary sampling rate
     for (samplingRate <- List(0.01, 0.05, 0.1, 0.5)) {
-      checkAllCombos(stratifiedData, samplingRate, defaultSeed, n)
+      StratifiedAuxiliary.testSampleExact(stratifiedData, samplingRate, 
defaultSeed, n)
     }
   }
 
@@ -556,6 +519,98 @@ class PairRDDFunctionsSuite extends FunSuite with 
SharedSparkContext {
     intercept[IllegalArgumentException] {shuffled.lookup(-1)}
   }
 
+  private object StratifiedAuxiliary {
+    def stratifier (fractionPositive: Double) = {
+      (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
+    }
+
+    def checkSize(exact: Boolean,
+        withReplacement: Boolean,
+        expected: Long,
+        actual: Long,
+        p: Double): Boolean = {
+      if (exact) {
+        return expected == actual
+      }
+      val stdev = if (withReplacement) math.sqrt(expected) else 
math.sqrt(expected * p * (1 - p))
+      // Very forgiving margin since we're dealing with very small sample 
sizes most of the time
+      math.abs(actual - expected) <= 6 * stdev
+    }
+
+    def testSampleExact(stratifiedData: RDD[(String, Int)],
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      testBernoulli(stratifiedData, true, samplingRate, seed, n)
+      testPoisson(stratifiedData, true, samplingRate, seed, n)
+    }
+
+    def testSample(stratifiedData: RDD[(String, Int)],
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      testBernoulli(stratifiedData, false, samplingRate, seed, n)
+      testPoisson(stratifiedData, false, samplingRate, seed, n)
+    }
+
+    // Without replacement validation
+    def testBernoulli(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey()
+        .mapValues(count => math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = if (exact) {
+        stratifiedData.sampleByKeyExact(false, fractions, seed)
+      } else {
+        stratifiedData.sampleByKey(false, fractions, seed)
+      }
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case(k, v) =>
+        assert(checkSize(exact, false, expectedSampleSize(k), v, 
samplingRate)) }
+      assert(takeSample.size === takeSample.toSet.size)
+      takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not 
in [1, $n]") }
+    }
+
+    // With replacement validation
+    def testPoisson(stratifiedData: RDD[(String, Int)],
+        exact: Boolean,
+        samplingRate: Double,
+        seed: Long,
+        n: Long) = {
+      val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
+        math.ceil(count * samplingRate).toInt)
+      val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
+      val sample = if (exact) {
+        stratifiedData.sampleByKeyExact(true, fractions, seed)
+      } else {
+        stratifiedData.sampleByKey(true, fractions, seed)
+      }
+      val sampleCounts = sample.countByKey()
+      val takeSample = sample.collect()
+      sampleCounts.foreach { case (k, v) =>
+        assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate))
+      }
+      val groupedByKey = takeSample.groupBy(_._1)
+      for ((key, v) <- groupedByKey) {
+        if (expectedSampleSize(key) >= 100 && samplingRate >= 0.1) {
+          // sample large enough for there to be repeats with high likelihood
+          assert(v.toSet.size < expectedSampleSize(key))
+        } else {
+          if (exact) {
+            assert(v.toSet.size <= expectedSampleSize(key))
+          } else {
+            assert(checkSize(false, true, expectedSampleSize(key), 
v.toSet.size, samplingRate))
+          }
+        }
+      }
+      takeSample.foreach(x => assert(1 <= x._2 && x._2 <= n, s"elements not in 
[1, $n]"))
+    }
+  }
+
 }
 
 /*


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to