GitHub user srowen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20472#discussion_r166771387
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
    @@ -1001,11 +996,18 @@ private[spark] object RandomForest extends Logging {
         } else {
           val numSplits = metadata.numSplits(featureIndex)
     
    -      // get count for each distinct value
    -      val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
    +      // get count for each distinct value except zero value
    +      val (partValueCountMap, partNumSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
             case ((m, cnt), x) =>
               (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
           }
    +
    +      // Calculate the number of samples for finding splits
    +      val numSamples: Int = (samplesFractionForFindSplits(metadata) * metadata.numExamples).toInt
    +
    +      // add zero value count and get complete statistics
    +      val valueCountMap: Map[Double, Int] = partValueCountMap + (0.0 -> (numSamples - partNumSamples))
    --- End diff --
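The counting logic added in the diff can be sketched in plain Scala as below. This is a minimal, hypothetical helper (`ValueCounts.valueCountMap` is not a Spark API): it assumes `nonZeroSamples` holds only the non-zero sampled values of one continuous feature, and it takes `numSamples` as a parameter instead of computing it from `samplesFractionForFindSplits(metadata) * metadata.numExamples`.

```scala
object ValueCounts {
  // Hypothetical helper mirroring the diff (not Spark code).
  // Assumption: `nonZeroSamples` contains only non-zero values, and
  // `numSamples` is the total number of samples, zeros included.
  def valueCountMap(nonZeroSamples: Seq[Double], numSamples: Int): Map[Double, Int] = {
    // Count each distinct non-zero value and the number of non-zero samples.
    val (partValueCountMap, partNumSamples) =
      nonZeroSamples.foldLeft((Map.empty[Double, Int], 0)) {
        case ((m, cnt), x) => (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
      }
    // Every sample not seen above is treated as an implicit zero.
    partValueCountMap + (0.0 -> (numSamples - partNumSamples))
  }
}
```

For example, non-zero samples `Seq(1.0, 2.0, 1.0)` with `numSamples = 10` give `Map(1.0 -> 2, 2.0 -> 1, 0.0 -> 7)`.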
    
    This also probably doesn't matter, but won't the new (0.0, ...) element
    always come first? You could append it below, after sorting the rest, rather
    than adding it to the Map earlier.
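The idea of adding the zero entry only after the non-zero counts are sorted can be sketched as follows (again a hypothetical helper, not Spark code). Putting zero at the front is only correct when every non-zero value is positive, so this sketch inserts the zero entry after any negative values. When there are no negative values, that is the same as prepending it.

```scala
object SortedValueCounts {
  // Hypothetical helper (not Spark code): `nonZeroCounts` plays the role of
  // the diff's partValueCountMap, and `zeroCount` is the implicit zero count.
  def sortedWithZero(nonZeroCounts: Map[Double, Int], zeroCount: Int): Array[(Double, Int)] = {
    // Sort only the non-zero entries by value.
    val sorted = nonZeroCounts.toSeq.sortBy(_._1)
    // Zero belongs after every negative value; if all values are positive,
    // `negatives` is empty and the zero entry simply goes first.
    val (negatives, positives) = sorted.span(_._1 < 0.0)
    (negatives ++ ((0.0, zeroCount) +: positives)).toArray
  }
}
```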

