Repository: spark Updated Branches: refs/heads/master 7f87ab981 -> 586e716e4
Reservoir sampling implementation. This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568 Author: Reynold Xin <[email protected]> Closes #1478 from rxin/reservoirSample and squashes the following commits: 17bcbf3 [Reynold Xin] Added seed. badf20d [Reynold Xin] Renamed the method. 6940010 [Reynold Xin] Reservoir sampling implementation. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/586e716e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/586e716e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/586e716e Branch: refs/heads/master Commit: 586e716e47305cd7c2c3ff35c0e828b63ef2f6a8 Parents: 7f87ab9 Author: Reynold Xin <[email protected]> Authored: Fri Jul 18 12:41:50 2014 -0700 Committer: Reynold Xin <[email protected]> Committed: Fri Jul 18 12:41:50 2014 -0700 ---------------------------------------------------------------------- .../spark/util/random/SamplingUtils.scala | 46 ++++++++++++++++++++ .../spark/util/random/SamplingUtilsSuite.scala | 21 +++++++++ 2 files changed, 67 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index a79e3ee..d10141b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -17,9 +17,55 @@ package org.apache.spark.util.random +import scala.reflect.ClassTag +import scala.util.Random + private[spark] object SamplingUtils { /** + * Reservoir sampling implementation that also returns the input size. 
+   *
+   * Implements Algorithm R. The first k elements fill the reservoir. Each later element,
+   * being the n-th element seen (1-based, n > k), replaces a uniformly chosen slot with
+   * probability k/n, so every input element is equally likely to end up in the sample.
+   *
+   * @param input iterator over the elements to sample from; it is fully consumed
+   * @param k reservoir size, i.e. the maximum number of samples returned
+   * @param seed random seed used for the replacement step
+   * @return (samples, input size); samples holds min(k, input size) elements
+   */
+  def reservoirSampleAndCount[T: ClassTag](
+      input: Iterator[T],
+      k: Int,
+      seed: Long = Random.nextLong())
+    : (Array[T], Int) = {
+    val reservoir = new Array[T](k)
+    // Put the first k elements in the reservoir.
+    var i = 0
+    while (i < k && input.hasNext) {
+      val item = input.next()
+      reservoir(i) = item
+      i += 1
+    }
+
+    // If we have consumed all the elements, return them. Otherwise do the replacement.
+    if (i < k) {
+      // If input size < k, trim the array to return only an array of input size.
+      val trimReservoir = new Array[T](i)
+      System.arraycopy(reservoir, 0, trimReservoir, 0, i)
+      (trimReservoir, i)
+    } else {
+      // Input size >= k: keep sampling from the remaining elements.
+      val rand = new XORShiftRandom(seed)
+      while (input.hasNext) {
+        val item = input.next()
+        // i elements have been seen so far, so this item is the (i + 1)-th and must be kept
+        // with probability k / (i + 1): draw uniformly from [0, i]. A bound of i would bias
+        // the sample toward later elements, and would be 0 (an invalid bound) when k == 0.
+        // NOTE(review): the count is an Int and overflows past Int.MaxValue elements.
+        val replacementIndex = rand.nextInt(i + 1)
+        if (replacementIndex < k) {
+          reservoir(replacementIndex) = item
+        }
+        i += 1
+      }
+      (reservoir, i)
+    }
+  }
+
+  /**
    * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
    * the time.
* http://git-wip-us.apache.org/repos/asf/spark/blob/586e716e/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala index accfe2e..73a9d02 100644 --- a/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala @@ -17,11 +17,32 @@ package org.apache.spark.util.random +import scala.util.Random + import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import org.scalatest.FunSuite class SamplingUtilsSuite extends FunSuite { + test("reservoirSampleAndCount") { + val input = Seq.fill(100)(Random.nextInt()) + + // input size < k + val (sample1, count1) = SamplingUtils.reservoirSampleAndCount(input.iterator, 150) + assert(count1 === 100) + assert(input === sample1.toSeq) + + // input size == k + val (sample2, count2) = SamplingUtils.reservoirSampleAndCount(input.iterator, 100) + assert(count2 === 100) + assert(input === sample2.toSeq) + + // input size > k + val (sample3, count3) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10) + assert(count3 === 100) + assert(sample3.length === 10) + } + test("computeFraction") { // test that the computed fraction guarantees enough data points // in the sample with a failure rate <= 0.0001
