GitHub user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/1866#discussion_r16032166
--- Diff: core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala ---
@@ -197,33 +197,57 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Return a subset of this RDD sampled by key (via stratified sampling).
*
* Create a sample of this RDD using variable sampling rates for different keys as specified by
- * `fractions`, a key to sampling rate map.
- *
- * If `exact` is set to false, create the sample via simple random sampling, with one pass
- * over the RDD, to produce a sample of size that's approximately equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values; otherwise, use
- * additional passes over the RDD to create a sample size that's exactly equal to the sum of
- * math.ceil(numItems * samplingRate) over all key values with a 99.99% confidence. When sampling
- * without replacement, we need one additional pass over the RDD to guarantee sample size;
- * when sampling with replacement, we need two additional passes.
+ * `fractions`, a key to sampling rate map, via simple random sampling with one pass over the
+ * RDD, to produce a sample of size that's approximately equal to the sum of
+ * math.ceil(numItems * samplingRate) over all key values.
*
* @param withReplacement whether to sample with or without replacement
* @param fractions map of specific keys to sampling rates
* @param seed seed for the random number generator
- * @param exact whether sample size needs to be exactly math.ceil(fraction * size) per key
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fractions: Map[K, Double],
- exact: Boolean = false,
- seed: Long = Utils.random.nextLong): RDD[(K, V)]= {
+ seed: Long = Utils.random.nextLong): RDD[(K, V)] = {
+
+ require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
+
+ val samplingFunc = if (withReplacement) {
+ StratifiedSamplingUtils.getPoissonSamplingFunction(self, fractions, false, seed)
+ } else {
+ StratifiedSamplingUtils.getBernoulliSamplingFunction(self, fractions, false, seed)
+ }
+ }
+ self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
+ }
+
+ /**
+ * ::Experimental::
+ *
--- End diff --
ditto: remove this line
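
For context, the one-pass, approximate-size behaviour described in the new doc comment can be sketched on a plain Scala collection. This is only an illustration under stated assumptions, not Spark's implementation: `SampleByKeySketch` and its in-memory `sampleByKey` are hypothetical names, only the without-replacement (Bernoulli) case is shown, and treating keys absent from `fractions` as rate 0.0 is an assumption made here for simplicity.

```scala
import scala.util.Random

// Hypothetical helper, not part of Spark: one-pass Bernoulli sampling by key
// (sampling without replacement) over an in-memory sequence.
object SampleByKeySketch {
  def sampleByKey[K, V](
      data: Seq[(K, V)],
      fractions: Map[K, Double],
      seed: Long): Seq[(K, V)] = {
    // Same precondition as the check added in the diff.
    require(fractions.values.forall(v => v >= 0.0), "Negative sampling rates.")
    val rng = new Random(seed)
    // Keep each item independently with probability fractions(key).
    // Assumption: keys missing from `fractions` are never sampled.
    data.filter { case (k, _) => rng.nextDouble() < fractions.getOrElse(k, 0.0) }
  }

  def main(args: Array[String]): Unit = {
    val data = (1 to 1000).map(i => (if (i % 2 == 0) "even" else "odd", i))
    val sample = sampleByKey(data, Map("even" -> 0.5, "odd" -> 0.1), seed = 42L)
    // Per-key sample sizes are only approximately fraction * count.
    sample.groupBy(_._1).foreach { case (k, vs) => println(s"$k: ${vs.size}") }
  }
}
```

Because each item is kept by an independent coin flip, the per-key sample size is only close to the requested fraction of that key's items, which matches the "approximately equal" wording in the revised comment.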