Github user dusenberrymw commented on a diff in the pull request:

    https://github.com/apache/spark/pull/8112#discussion_r40272553
  
    --- Diff: core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala ---
    @@ -263,6 +263,80 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       }
     
       /**
    +   * ::Experimental::
    +   * Return random, non-overlapping splits of this RDD sampled by key (via stratified sampling)
    +   * with each split containing exactly math.ceil(numItems * samplingRate) for each stratum.
    +   *
    +   * This method differs from [[sampleByKey]] and [[sampleByKeyExact]] in that it provides random
    +   * splits (and their complements) instead of just a subsample of the data. This requires
    +   * segmenting random keys into ranges with upper and lower bounds instead of segmenting the keys
    +   * into a high/low bisection of the entire dataset.
    +   *
    +   * @param weights array of maps of (key -> samplingRate) pairs for each split, normed by key
    +   * @param exact boolean specifying whether to use exact subsampling
    +   * @param seed seed for the random number generator
    +   * @return array of tuples containing the subsample and complement RDDs for each split
    +   */
    +  @Experimental
    +  def randomSplitByKey(
    +     weights: Array[Map[K, Double]],
    +     exact: Boolean = false,
    +     seed: Long = Utils.random.nextLong): Array[(RDD[(K, V)], RDD[(K, V)])] = self.withScope {
    +
    +    require(weights.flatMap(_.values).forall(v => v >= 0.0), "Negative sampling rates.")
    +    if (weights.length > 1) {
    +      require(weights.map(m => m.keys.toList).sliding(2).forall(t => t(0) == t(1)),
    +        "Inconsistent keys between splits.")
    +    }
    +
    +    // normalize and cumulative sum
    +    val baseFold = weights(0).map(x => (x._1, 0.0))
    +    val cumWeightsByKey = weights.scanLeft(baseFold){ case (accMap, iterMap) =>
    +      accMap.map { case (k, v) => (k, v + iterMap(k)) }
    +    }.drop(1)
    +
    +    val weightSumsByKey = cumWeightsByKey.last
    +    val normedCumWeightsByKey = cumWeightsByKey.dropRight(1).map(_.map { case (key, threshold) =>
    +      (key, threshold / weightSumsByKey(key))
    +    })
    +
    +    // compute exact thresholds for each stratum if required
    +    val splits = if (exact) {
    --- End diff --
    
    I would suggest renaming `splits` here to indicate that it contains only the __points__ at which to split the data, not the left (0) and right (1) endpoints as well.
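    For illustration, a minimal sketch of the distinction using plain Scala collections (no Spark), assuming the same normalize-and-cumulative-sum step as the diff above. The names `SplitPoints`, `splitPointsByKey` and `boundsForKey` are hypothetical and not part of the PR: the first returns only the interior cut points for each key (k - 1 values for k splits), and the second adds the 0.0 and 1.0 endpoints to produce one (lower, upper) range per split.

    ```scala
    // Hypothetical helpers (not from the PR) separating interior split points
    // from the full set of bounds that also includes 0.0 and 1.0.
    object SplitPoints {

      /** Per-key cumulative weights, normalized so each key's total is 1.0.
        * The last cumulative map (always 1.0 after normalizing) is dropped,
        * leaving only the interior points at which [0, 1) is cut. */
      def splitPointsByKey[K](weights: Array[Map[K, Double]]): Array[Map[K, Double]] = {
        require(weights.nonEmpty, "At least one split is required.")
        val zeros: Map[K, Double] = weights(0).map { case (k, _) => (k, 0.0) }
        // Running sum of each key's weight across the splits.
        val cumulative = weights.scanLeft(zeros) { (acc, w) =>
          acc.map { case (k, v) => (k, v + w(k)) }
        }.drop(1)
        val totals = cumulative.last
        // Normalize by each key's total and keep only the interior points.
        cumulative.dropRight(1).map(_.map { case (k, c) => (k, c / totals(k)) })
      }

      /** (lower, upper) bounds of every split for one key: the split points
        * with the left (0.0) and right (1.0) endpoints added. */
      def boundsForKey[K](points: Array[Map[K, Double]], key: K): Seq[(Double, Double)] = {
        val edges: List[Double] = 0.0 :: points.map(p => p(key)).toList ::: List(1.0)
        edges.zip(edges.tail)
      }
    }
    ```

    With weights of 1:1:2 for a key, the split points are 0.25 and 0.5, while the bounds are (0.0, 0.25), (0.25, 0.5) and (0.5, 1.0). Keeping the two apart means a name like `splitPoints` stays accurate for the first, and the endpoints are only added where the ranges are actually built.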

