This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d8a934  [SPARK-26721][ML] Avoid per-tree normalization in 
featureImportance for GBT
5d8a934 is described below

commit 5d8a934c13420dcce9d68cbf1f5f30381978d32e
Author: Marco Gaido <marcogaid...@gmail.com>
AuthorDate: Sat Feb 16 16:51:01 2019 -0600

    [SPARK-26721][ML] Avoid per-tree normalization in featureImportance for GBT
    
    ## What changes were proposed in this pull request?
    
    Our feature importance calculation is taken from sklearn's one, which has 
been recently fixed (in 
https://github.com/scikit-learn/scikit-learn/pull/11176). Citing the 
description of that PR:
    
    > Because the feature importances are (currently, by default) normalized 
and then averaged, feature importances from later stages are overweighted.
    
    The PR performs a fix similar to sklearn's one. The per-tree normalization 
of the feature importance is skipped for GBT.
    
    Credits to Daniel Jumper for clearly pointing out the issue and sklearn's 
PR.
    
    ## How was this patch tested?
    
    modified UT, checked that the computed `featureImportance` in that test is 
similar to sklearn's one (it can't be the same, because the trees may be 
slightly different)
    
    Closes #23773 from mgaido91/SPARK-26721.
    
    Authored-by: Marco Gaido <marcogaid...@gmail.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../spark/ml/classification/GBTClassifier.scala    |  5 +++--
 .../apache/spark/ml/regression/GBTRegressor.scala  |  3 ++-
 .../org/apache/spark/ml/tree/treeModels.scala      | 23 ++++++++++++++++++----
 .../ml/classification/GBTClassifierSuite.scala     |  3 ++-
 .../spark/ml/regression/GBTRegressorSuite.scala    |  3 ++-
 5 files changed, 28 insertions(+), 9 deletions(-)

diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index abe2d1f..a5ed4a3 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -341,11 +341,12 @@ class GBTClassificationModel private[ml](
    * The importance vector is normalized to sum to 1. This method is suggested 
by Hastie et al.
    * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd 
Edition." 2001.)
    * and follows the implementation from scikit-learn.
-
+   *
    * See `DecisionTreeClassificationModel.featureImportances`
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector =
+    TreeEnsembleModel.featureImportances(trees, numFeatures, 
perTreeNormalization = false)
 
   /** Raw prediction for the positive class. */
   private def margin(features: Vector): Double = {
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 9a5b7d5..9f0f567 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -285,7 +285,8 @@ class GBTRegressionModel private[ml](
    * @see `DecisionTreeRegressionModel.featureImportances`
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = 
TreeEnsembleModel.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector =
+    TreeEnsembleModel.featureImportances(trees, numFeatures, 
perTreeNormalization = false)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 51d5d5c..e95c55f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -135,7 +135,7 @@ private[ml] object TreeEnsembleModel {
    *  - Average over trees:
    *     - importance(feature j) = sum (over nodes which split on feature j) 
of the gain,
    *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
+   *     - Normalize importances for tree to sum to 1 (only if 
`perTreeNormalization` is `true`).
    *  - Normalize feature importance vector to sum to 1.
    *
    *  References:
@@ -145,9 +145,15 @@ private[ml] object TreeEnsembleModel {
    * @param numFeatures  Number of features in model (even if not all are 
explicitly used by
    *                     the model).
    *                     If -1, then numFeatures is set based on the max 
feature index in all trees.
+   * @param perTreeNormalization By default this is set to `true` and it means 
that the importances
+   *                             of each tree are normalized before being 
summed. If set to `false`,
+   *                             the normalization is skipped.
    * @return  Feature importance values, of length numFeatures.
    */
-  def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: 
Int): Vector = {
+  def featureImportances[M <: DecisionTreeModel](
+      trees: Array[M],
+      numFeatures: Int,
+      perTreeNormalization: Boolean = true): Vector = {
     val totalImportances = new OpenHashMap[Int, Double]()
     trees.foreach { tree =>
       // Aggregate feature importance vector for this tree
@@ -155,10 +161,19 @@ private[ml] object TreeEnsembleModel {
       computeFeatureImportance(tree.rootNode, importances)
       // Normalize importance vector for this tree, and add it to total.
       // TODO: In the future, also support normalizing by 
tree.rootNode.impurityStats.count?
-      val treeNorm = importances.map(_._2).sum
+      val treeNorm = if (perTreeNormalization) {
+        importances.map(_._2).sum
+      } else {
+        // We won't use it
+        Double.NaN
+      }
       if (treeNorm != 0) {
         importances.foreach { case (idx, impt) =>
-          val normImpt = impt / treeNorm
+          val normImpt = if (perTreeNormalization) {
+            impt / treeNorm
+          } else {
+            impt
+          }
           totalImportances.changeValue(idx, normImpt, _ + normImpt)
         }
       }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index cedbaf1..cd59900 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -363,7 +363,8 @@ class GBTClassifierSuite extends MLTest with 
DefaultReadWriteTest {
     val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
     val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
     val mostIF = importanceFeatures.argmax
-    assert(mostImportantFeature !== mostIF)
+    assert(mostIF === 1)
+    assert(importances(mostImportantFeature) !== importanceFeatures(mostIF))
   }
 
   test("model evaluateEachIteration") {
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index b145c7a..46fa376 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -200,7 +200,8 @@ class GBTRegressorSuite extends MLTest with 
DefaultReadWriteTest {
     val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1")
     val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances
     val mostIF = importanceFeatures.argmax
-    assert(mostImportantFeature !== mostIF)
+    assert(mostIF === 1)
+    assert(importances(mostImportantFeature) !== importanceFeatures(mostIF))
   }
 
   test("model evaluateEachIteration") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to