GitHub user smurching commented on a diff in the pull request:
https://github.com/apache/spark/pull/19433#discussion_r151011913
--- Diff: mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala ---
@@ -627,221 +621,37 @@ private[spark] object RandomForest extends Logging {
}
/**
- * Calculate the impurity statistics for a given (feature, split) based upon left/right
- * aggregates.
- *
- * @param stats the recycle impurity statistics for this feature's all splits,
- * only 'impurity' and 'impurityCalculator' are valid between each iteration
- * @param leftImpurityCalculator left node aggregates for this (feature, split)
- * @param rightImpurityCalculator right node aggregate for this (feature, split)
- * @param metadata learning and dataset metadata for DecisionTree
- * @return Impurity statistics for this (feature, split)
+ * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global
+ * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the
+ * list of features for a given node. Filters out constant features (features with 0 splits)
*/
- private def calculateImpurityStats(
- stats: ImpurityStats,
- leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata): ImpurityStats = {
-
- val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
- leftImpurityCalculator.copy.add(rightImpurityCalculator)
- } else {
- stats.impurityCalculator
- }
-
- val impurity: Double = if (stats == null) {
- parentImpurityCalculator.calculate()
- } else {
- stats.impurity
- }
-
- val leftCount = leftImpurityCalculator.count
- val rightCount = rightImpurityCalculator.count
-
- val totalCount = leftCount + rightCount
-
- // If left child or right child doesn't satisfy minimum instances per node,
- // then this split is invalid, return invalid information gain stats.
- if ((leftCount < metadata.minInstancesPerNode) ||
- (rightCount < metadata.minInstancesPerNode)) {
- return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
- }
-
- val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
- val rightImpurity = rightImpurityCalculator.calculate()
-
- val leftWeight = leftCount / totalCount.toDouble
- val rightWeight = rightCount / totalCount.toDouble
-
- val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
-
- // if information gain doesn't satisfy minimum information gain,
- // then this split is invalid, return invalid information gain stats.
- if (gain < metadata.minInfoGain) {
- return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
+ private[impl] def getNonConstantFeatures(
+ metadata: DecisionTreeMetadata,
+ featuresForNode: Option[Array[Int]]): Seq[(Int, Int)] = {
+ Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx =>
--- End diff --
At some point during the refactoring I was hitting errors caused by a stateful
operation inside a `map` over the output of this method (IIRC the result of the
`map` was accessed repeatedly, which inadvertently ran the stateful operation
multiple times).
However, using `withFilter` and `view` now seems to work, so I'll change it
back :)
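
For anyone following along, here is a minimal sketch of the kind of re-evaluation described above. It is not the code from this PR: `LazyViewDemo`, `callsWithView`, and `callsWithStrictMap` are hypothetical names, and the counter only stands in for whatever stateful operation was involved. It assumes nothing beyond the Scala standard library.

```scala
object LazyViewDemo {
  // Map over a lazy view, then traverse the result twice.
  // The mapping function is re-run on every traversal.
  def callsWithView(): Int = {
    var calls = 0 // stands in for the stateful operation
    val result = Range(0, 3).view.map { i => calls += 1; i * 2 }
    result.foreach(_ => ())
    result.foreach(_ => ())
    calls // 6: three elements, two traversals
  }

  // Map strictly, then traverse the result twice.
  // The mapping function runs once per element, when the collection is built.
  def callsWithStrictMap(): Int = {
    var calls = 0
    val result = Range(0, 3).map { i => calls += 1; i * 2 }
    result.foreach(_ => ())
    result.foreach(_ => ())
    calls // 3: evaluated once, up front
  }

  def main(args: Array[String]): Unit = {
    println(s"view:   ${callsWithView()} calls")
    println(s"strict: ${callsWithStrictMap()} calls")
  }
}
```

So a `view` is only safe here if the function passed to `map` is free of side effects, or if the result is materialized (for example with `toVector`) before it is read more than once.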