Github user smurching commented on a diff in the pull request:
https://github.com/apache/spark/pull/19433#discussion_r151019591
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -852,6 +662,41 @@ private[spark] object RandomForest extends Logging {
}
/**
+ * Find the best split for a node.
+ *
+ * @param binAggregates Bin statistics.
+ * @return tuple for best split: (Split, information gain, prediction at node)
+ */
+ private[tree] def binsToBestSplit(
+ binAggregates: DTStatsAggregator,
+ splits: Array[Array[Split]],
+ featuresForNode: Option[Array[Int]],
+ node: LearningNode): (Split, ImpurityStats) = {
+ val validFeatureSplits =
+   getNonConstantFeatures(binAggregates.metadata, featuresForNode)
+ // For each (feature, split), calculate the gain, and select the best (feature, split).
+ val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator)
--- End diff ---
I believe so; the nodes at the top level are created
([RandomForest.scala:178](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala#L178))
with
[`LearningNode.emptyNode`](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala#L341),
which sets `node.stats = null`.
I could change this to check node depth (via node index), but if we're
planning on deprecating node indices in the future, it might be best not to.
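
For context, here's a minimal standalone sketch of the convention this relies on (the classes below are simplified stand-ins for illustration only, not the real `DTStatsAggregator`/`ImpurityStats`/`LearningNode` internals): a top-level node created via `emptyNode` carries `stats == null`, so wrapping it in an `Option` yields `None` without consulting node depth or node indices.

```scala
// Hedged sketch only: simplified stand-ins for the real ml.tree classes.
object ParentImpuritySketch {
  final case class ImpurityCalculator(count: Long)
  final case class ImpurityStats(impurityCalculator: ImpurityCalculator)
  // `stats` is null for nodes created at the top level (emptyNode-style).
  final case class Node(id: Int, stats: ImpurityStats)

  // Same idea as the diff above: None for root-level nodes, Some(...) otherwise.
  def parentImpurityCalc(node: Node): Option[ImpurityCalculator] =
    Option(node.stats).map(_.impurityCalculator)

  def main(args: Array[String]): Unit = {
    val root  = Node(id = 1, stats = null)
    val child = Node(id = 2, stats = ImpurityStats(ImpurityCalculator(42L)))
    println(parentImpurityCalc(root))   // None
    println(parentImpurityCalc(child))  // Some(ImpurityCalculator(42))
  }
}
```

Keeping the check on `stats` rather than on node index/depth means nothing here would need to change if node indices are eventually deprecated.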
---