GitHub user sethah commented on a diff in the pull request:
https://github.com/apache/spark/pull/20472#discussion_r169386551
--- Diff:
mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -1001,11 +996,18 @@ private[spark] object RandomForest extends Logging {
} else {
val numSplits = metadata.numSplits(featureIndex)
- // get count for each distinct value
- val (valueCountMap, numSamples) =
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+ // get count for each distinct value except zero value
+ val (partValueCountMap, partNumSamples) =
featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
case ((m, cnt), x) =>
(m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
}
+
+ // Calculate the number of samples for finding splits
+ val numSamples: Int = (samplesFractionForFindSplits(metadata) *
metadata.numExamples).toInt
+
+ // add zero value count and get complete statistics
+ val valueCountMap: Map[Double, Int] = partValueCountMap + (0.0 ->
(numSamples - partNumSamples))
--- End diff --
There can be negative values, right?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]