Repository: spark
Updated Branches:
  refs/heads/master d3638d7bf -> f6066b0c3


[SPARK-11730][ML] Add feature importances for GBTs.

## What changes were proposed in this pull request?

Now that GBTs have been moved to ML, they can use the implementation of feature 
importance for random forests. This patch simply adds a `featureImportances` 
attribute to `GBTClassifier` and `GBTRegressor` and adds tests for each.

GBT feature importances here simply average the feature importances for each 
tree in its ensemble. This follows the implementation from scikit-learn. This 
method is also suggested by J Friedman in [this 
paper](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf).

## How was this patch tested?

Unit tests were added to `GBTClassifierSuite` and `GBTRegressorSuite` to 
validate feature importances.

Author: sethah <[email protected]>

Closes #11961 from sethah/SPARK-11730.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6066b0c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6066b0c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6066b0c

Branch: refs/heads/master
Commit: f6066b0c3c35ceea1706378145e15776c9b4415a
Parents: d3638d7
Author: sethah <[email protected]>
Authored: Mon Mar 28 22:27:53 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Mon Mar 28 22:27:53 2016 -0700

----------------------------------------------------------------------
 .../classification/DecisionTreeClassifier.scala |   2 +-
 .../spark/ml/classification/GBTClassifier.scala |  13 ++
 .../classification/RandomForestClassifier.scala |  16 +--
 .../ml/regression/DecisionTreeRegressor.scala   |   2 +-
 .../spark/ml/regression/GBTRegressor.scala      |  13 ++
 .../ml/regression/RandomForestRegressor.scala   |  16 +--
 .../ml/tree/impl/GradientBoostedTrees.scala     |   2 +
 .../spark/ml/tree/impl/RandomForest.scala       | 110 -----------------
 .../org/apache/spark/ml/tree/treeModels.scala   | 120 +++++++++++++++++++
 .../ml/classification/GBTClassifierSuite.scala  |  25 ++++
 .../spark/ml/regression/GBTRegressorSuite.scala |  23 ++++
 .../spark/ml/tree/impl/RandomForestSuite.scala  |   6 +-
 12 files changed, 213 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21b..23c4af1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -203,7 +203,7 @@ final class DecisionTreeClassificationModel private[ml] (
    *       to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, 
numFeatures)
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
   override private[spark] def toOld: OldDecisionTreeModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3a..48ce051 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -238,6 +238,19 @@ final class GBTClassificationModel private[ml](
     s"GBTClassificationModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all 
trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested 
by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd 
Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d3..82fa05a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -222,19 +222,15 @@ final class RandomForestClassificationModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" 
documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   * Each feature's importance is the average of its importance across all 
trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested 
by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd 
Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96e..0a3d00e 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -203,7 +203,7 @@ final class DecisionTreeRegressionModel private[ml] (
    *       to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, 
numFeatures)
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
   override private[spark] def toOld: OldDecisionTreeModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e..8fca35d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -224,6 +224,19 @@ final class GBTRegressionModel private[ml](
     s"GBTRegressionModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all 
trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested 
by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd 
Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b..5b3f3a1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -181,19 +181,15 @@ final class RandomForestRegressionModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" 
documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   * Each feature's importance is the average of its importance across all 
trees in the ensemble.
+   * The importance vector is normalized to sum to 1. This method is suggested 
by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd 
Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, 
numFeatures)
+  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4..b37f4e8 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,9 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, 
DecisionTreeRegressor}
+import org.apache.spark.ml.tree.DecisionTreeModel
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => 
OldBoostingStrategy}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae6..cccf052 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => 
OldStrategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, 
DTStatsAggregator,
@@ -35,7 +34,6 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
 
 
@@ -1105,112 +1103,4 @@ private[spark] object RandomForest extends Logging {
     }
   }
 
-  /**
-   * Given a Random Forest model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" 
documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
-   *
-   * @param trees  Unweighted forest of trees
-   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
-   *                     the model).
-   *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
-   * @return  Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(trees: Array[DecisionTreeModel], 
numFeatures: Int): Vector = {
-    val totalImportances = new OpenHashMap[Int, Double]()
-    trees.foreach { tree =>
-      // Aggregate feature importance vector for this tree
-      val importances = new OpenHashMap[Int, Double]()
-      computeFeatureImportance(tree.rootNode, importances)
-      // Normalize importance vector for this tree, and add it to total.
-      // TODO: In the future, also support normalizing by 
tree.rootNode.impurityStats.count?
-      val treeNorm = importances.map(_._2).sum
-      if (treeNorm != 0) {
-        importances.foreach { case (idx, impt) =>
-          val normImpt = impt / treeNorm
-          totalImportances.changeValue(idx, normImpt, _ + normImpt)
-        }
-      }
-    }
-    // Normalize importances
-    normalizeMapValues(totalImportances)
-    // Construct vector
-    val d = if (numFeatures != -1) {
-      numFeatures
-    } else {
-      // Find max feature index used in trees
-      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
-      maxFeatureIndex + 1
-    }
-    if (d == 0) {
-      assert(totalImportances.size == 0, s"Unknown error in computing feature" 
+
-        s" importance: No splits found, but some non-zero importances.")
-    }
-    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
-    Vectors.sparse(d, indices.toArray, values.toArray)
-  }
-
-  /**
-   * Given a Decision Tree model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" 
documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - importance(feature j) = sum (over nodes which split on feature j) of 
the gain,
-   *    where gain is scaled by the number of instances passing through node
-   *  - Normalize importances for tree to sum to 1.
-   *
-   * @param tree  Decision tree to compute importances for.
-   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
-   *                     the model).
-   *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
-   * @return  Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: 
Int): Vector = {
-    featureImportances(Array(tree), numFeatures)
-  }
-
-  /**
-   * Recursive method for computing feature importances for one tree.
-   * This walks down the tree, adding to the importance of 1 feature at each 
node.
-   * @param node  Current node in recursion
-   * @param importances  Aggregate feature importances, modified by this method
-   */
-  private[impl] def computeFeatureImportance(
-      node: Node,
-      importances: OpenHashMap[Int, Double]): Unit = {
-    node match {
-      case n: InternalNode =>
-        val feature = n.split.featureIndex
-        val scaledGain = n.gain * n.impurityStats.count
-        importances.changeValue(feature, scaledGain, _ + scaledGain)
-        computeFeatureImportance(n.leftChild, importances)
-        computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
-        // do nothing
-    }
-  }
-
-  /**
-   * Normalize the values of this map to sum to 1, in place.
-   * If all values are 0, this method does nothing.
-   * @param map  Map with non-negative values.
-   */
-  private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
-    val total = map.map(_._2).sum
-    if (total != 0) {
-      val keys = map.iterator.map(_._1).toArray
-      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
-    }
-  }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c90..1fad9d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => 
OldDecisionTreeModel}
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Abstraction for Decision Tree models.
@@ -115,6 +116,125 @@ private[ml] trait TreeEnsembleModel {
   lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
 }
 
+private[ml] object TreeEnsembleModel {
+
+  /**
+   * Given a tree ensemble model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   *  For collections of trees, including boosting and bagging, Hastie et al.
+   *  propose to use the average of single tree importances across all trees 
in the ensemble.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1.
+   *  - Normalize feature importance vector to sum to 1.
+   *
+   *  References:
+   *  - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 
2nd Edition." 2001.
+   *
+   * @param trees  Unweighted collection of trees
+   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): 
Vector = {
+    val totalImportances = new OpenHashMap[Int, Double]()
+    trees.foreach { tree =>
+      // Aggregate feature importance vector for this tree
+      val importances = new OpenHashMap[Int, Double]()
+      computeFeatureImportance(tree.rootNode, importances)
+      // Normalize importance vector for this tree, and add it to total.
+      // TODO: In the future, also support normalizing by 
tree.rootNode.impurityStats.count?
+      val treeNorm = importances.map(_._2).sum
+      if (treeNorm != 0) {
+        importances.foreach { case (idx, impt) =>
+          val normImpt = impt / treeNorm
+          totalImportances.changeValue(idx, normImpt, _ + normImpt)
+        }
+      }
+    }
+    // Normalize importances
+    normalizeMapValues(totalImportances)
+    // Construct vector
+    val d = if (numFeatures != -1) {
+      numFeatures
+    } else {
+      // Find max feature index used in trees
+      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+      maxFeatureIndex + 1
+    }
+    if (d == 0) {
+      assert(totalImportances.size == 0, s"Unknown error in computing feature" 
+
+        s" importance: No splits found, but some non-zero importances.")
+    }
+    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+    Vectors.sparse(d, indices.toArray, values.toArray)
+  }
+
+  /**
+   * Given a Decision Tree model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" 
documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from 
scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - importance(feature j) = sum (over nodes which split on feature j) of 
the gain,
+   *    where gain is scaled by the number of instances passing through node
+   *  - Normalize importances for tree to sum to 1.
+   *
+   * @param tree  Decision tree to compute importances for.
+   * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+    featureImportances(Array(tree), numFeatures)
+  }
+
+  /**
+   * Recursive method for computing feature importances for one tree.
+   * This walks down the tree, adding to the importance of 1 feature at each 
node.
+   *
+   * @param node  Current node in recursion
+   * @param importances  Aggregate feature importances, modified by this method
+   */
+  def computeFeatureImportance(
+      node: Node,
+      importances: OpenHashMap[Int, Double]): Unit = {
+    node match {
+      case n: InternalNode =>
+        val feature = n.split.featureIndex
+        val scaledGain = n.gain * n.impurityStats.count
+        importances.changeValue(feature, scaledGain, _ + scaledGain)
+        computeFeatureImportance(n.leftChild, importances)
+        computeFeatureImportance(n.rightChild, importances)
+      case n: LeafNode =>
+      // do nothing
+    }
+  }
+
+  /**
+   * Normalize the values of this map to sum to 1, in place.
+   * If all values are 0, this method does nothing.
+   *
+   * @param map  Map with non-negative values.
+   */
+  def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+    val total = map.map(_._2).sum
+    if (total != 0) {
+      val keys = map.iterator.map(_._1).toArray
+      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+    }
+  }
+}
+
 /** Helper classes for tree model persistence */
 private[ml] object DecisionTreeModelReadWrite {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed..bf7481e 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -121,6 +121,31 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   */
 
   /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature importance
+  /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    val numClasses = 2
+    val gbt = new GBTClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSubsamplingRate(1.0)
+      .setStepSize(0.5)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8..dfb8418 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -132,6 +132,29 @@ class GBTRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContext {
   */
 
   /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature importance
+  /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    val gbt = new GBTRegressor()
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSubsamplingRate(1.0)
+      .setStepSize(0.5)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6066b0c/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366f..441338e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -471,7 +471,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     // Test feature importance computed at different subtrees.
     def testNode(node: Node, expected: Map[Int, Double]): Unit = {
       val map = new OpenHashMap[Int, Double]()
-      RandomForest.computeFeatureImportance(node, map)
+      TreeEnsembleModel.computeFeatureImportance(node, map)
       assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
     }
 
@@ -493,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 
3)
         .asInstanceOf[DecisionTreeModel]
     }
-    val importances: Vector = RandomForest.featureImportances(trees, 2)
+    val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance
     val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
       (feature1importance / tree2norm) / 2.0)
@@ -504,7 +504,7 @@ class RandomForestSuite extends SparkFunSuite with 
MLlibTestSparkContext {
     val map = new OpenHashMap[Int, Double]()
     map(0) = 1.0
     map(2) = 2.0
-    RandomForest.normalizeMapValues(map)
+    TreeEnsembleModel.normalizeMapValues(map)
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to