Github user jkbradley commented on a diff in the pull request:
https://github.com/apache/spark/pull/19433#discussion_r150349566
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -852,6 +662,41 @@ private[spark] object RandomForest extends Logging {
}
/**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+ val validFeatureSplits = getNonConstantFeatures(binAggregates.metadata, featuresForNode)
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator)
--- End diff --
Note to check: is node.stats guaranteed to be null at the top level (the root node)?
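
For reference while checking this, here is a minimal, self-contained sketch of the same null guard written with `Option(...)`, which gives None for a null reference and Some otherwise. The types `ImpurityCalc`, `Stats`, and `Node` are hypothetical stand-ins invented for illustration, not Spark's actual `ImpurityCalculator`, `ImpurityStats`, or `LearningNode`. The sketch does not answer whether the root's stats are always null; it only shows that the guard behaves the same either way.

```scala
// Hypothetical stand-ins for illustration only; not Spark's real classes.
final case class ImpurityCalc(stats: Array[Double])
final case class Stats(impurityCalculator: ImpurityCalc)
final class Node(var stats: Stats)

object NullSafeStats {
  // Option(x) is None when x is null and Some(x) otherwise, so this is
  // equivalent to: if (node.stats == null) None else Some(node.stats.impurityCalculator)
  def parentImpurityCalc(node: Node): Option[ImpurityCalc] =
    Option(node.stats).map(_.impurityCalculator)

  def main(args: Array[String]): Unit = {
    val root = new Node(null) // assumption: a node whose stats were never set
    val child = new Node(Stats(ImpurityCalc(Array(3.0, 1.0))))
    println(parentImpurityCalc(root).isDefined)  // prints "false"
    println(parentImpurityCalc(child).isDefined) // prints "true"
  }
}
```

With this form, a node that already has stats takes the Some branch and one without takes the None branch, without an explicit null comparison at the call site.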
---