Repository: spark Updated Branches: refs/heads/branch-0.9 1dc1e988f -> af7e8b1c9
SPARK-1240: handle the case of empty RDD when takeSample https://spark-project.atlassian.net/browse/SPARK-1240 It seems that the current implementation does not handle the empty RDD case when run takeSample In this patch, before calling sample() inside takeSample API, I add a checker for this case and returns an empty Array when it's a empty RDD; also in sample(), I add a checker for the invalid fraction value In the test case, I also add several lines for this case Author: CodingCat <[email protected]> Closes #135 from CodingCat/SPARK-1240 and squashes the following commits: fef57d4 [CodingCat] fix the same problem in PySpark 36db06b [CodingCat] create new test cases for takeSample from an empty red 810948d [CodingCat] further fix a40e8fb [CodingCat] replace if with require ad483fd [CodingCat] handle the case with empty RDD when take sample Conflicts: core/src/main/scala/org/apache/spark/rdd/RDD.scala core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af7e8b1c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af7e8b1c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af7e8b1c Branch: refs/heads/branch-0.9 Commit: af7e8b1c9913ebb6d4131a03cfe0cd0a2b38c529 Parents: 1dc1e98 Author: CodingCat <[email protected]> Authored: Sun Mar 16 22:14:59 2014 -0700 Committer: Matei Zaharia <[email protected]> Committed: Sun Mar 16 22:40:22 2014 -0700 ---------------------------------------------------------------------- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 10 ++++++++-- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 7 +++++++ python/pyspark/rdd.py | 4 ++++ 3 files changed, 19 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/core/src/main/scala/org/apache/spark/rdd/RDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 1472c92..b529754 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -319,8 +319,10 @@ abstract class RDD[T: ClassTag]( /** * Return a sampled subset of this RDD. */ - def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = + def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { + require(fraction >= 0.0, "Invalid fraction value: " + fraction) new SampledRDD(this, withReplacement, fraction, seed) + } def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = { var fraction = 0.0 @@ -333,6 +335,10 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (initialCount > Integer.MAX_VALUE - 1) { maxSelected = Integer.MAX_VALUE - 1 } else { @@ -351,7 +357,7 @@ abstract class RDD[T: ClassTag]( var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size + // this shouldn't happen often because we use a big multiplier for the initial size while (samples.length < total) { samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() } http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 559ea05..ac9df34 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -455,6 +455,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) + for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements @@ -486,6 +487,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("takeSample from an empty rdd") { + val emptySet = sc.parallelize(Seq.empty[Int], 2) + val sample = emptySet.takeSample(false, 20, 1) + assert(sample.length === 0) + } + test("runJob on an invalid partition") { intercept[IllegalArgumentException] { sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) http://git-wip-us.apache.org/repos/asf/spark/blob/af7e8b1c/python/pyspark/rdd.py ---------------------------------------------------------------------- diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index c29cefa..6d05ff2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -254,6 +254,7 @@ class RDD(object): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ + assert fraction >= 0.0, "Invalid fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -274,6 +275,9 @@ class RDD(object): if (num < 0): raise ValueError + if (initialCount == 0): + return list() + if initialCount > sys.maxint - 1: maxSelected = sys.maxint - 1 else:
