Repository: spark
Updated Branches:
  refs/heads/master 6bb60b30f -> 1bb60ab83


[SPARK-26153][ML] GBT & RandomForest avoid unnecessary `first` job to compute 
`numFeatures`

## What changes were proposed in this pull request?
Use the trained base models' `numFeatures` instead of running a separate `first` job on the training dataset.

## How was this patch tested?
Existing tests.

Closes #23123 from zhengruifeng/avoid_first_job.

Authored-by: zhengruifeng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1bb60ab8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1bb60ab8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1bb60ab8

Branch: refs/heads/master
Commit: 1bb60ab8392adf8b896cc04fb1d060620cf09d8a
Parents: 6bb60b3
Author: zhengruifeng <[email protected]>
Authored: Mon Nov 26 05:57:33 2018 -0600
Committer: Sean Owen <[email protected]>
Committed: Mon Nov 26 05:57:33 2018 -0600

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/GBTClassifier.scala     | 5 +++--
 .../spark/ml/classification/RandomForestClassifier.scala       | 2 +-
 .../scala/org/apache/spark/ml/regression/GBTRegressor.scala    | 6 ++++--
 .../org/apache/spark/ml/regression/RandomForestRegressor.scala | 2 +-
 4 files changed, 9 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index fab8155..09a9df6 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") (
       (convert2LabeledPoint(dataset), null)
     }
 
-    val numFeatures = trainDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Classification)
 
     val numClasses = 2
@@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") (
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, 
minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, 
featureSubsetStrategy,
       validationIndicatorCol)
-    instr.logNumFeatures(numFeatures)
     instr.logNumClasses(numClasses)
 
     val (baseLearners, learnerWeights) = if (withValidation) {
@@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") (
       GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), 
$(featureSubsetStrategy))
     }
 
+    val numFeatures = baseLearners.head.numFeatures
+    instr.logNumFeatures(numFeatures)
+
     new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 05fff88..0a3bfd1 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") (
       .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeClassificationModel])
 
-    val numFeatures = oldDataset.first().features.size
+    val numFeatures = trees.head.numFeatures
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
     new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)

http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 186fa23..9b386ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -165,7 +165,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
     } else {
       (extractLabeledPoints(dataset), null)
     }
-    val numFeatures = trainDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Regression)
 
     instr.logPipelineStage(this)
@@ -173,7 +172,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
     instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, 
lossType,
       maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, 
minInstancesPerNode,
       seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, 
featureSubsetStrategy)
-    instr.logNumFeatures(numFeatures)
 
     val (baseLearners, learnerWeights) = if (withValidation) {
       GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, 
boostingStrategy,
@@ -182,6 +180,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
       GradientBoostedTrees.run(trainDataset, boostingStrategy,
         $(seed), $(featureSubsetStrategy))
     }
+
+    val numFeatures = baseLearners.head.numFeatures
+    instr.logNumFeatures(numFeatures)
+
     new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 7f5e668..afa9a64 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") 
(@Since("1.4.0") override val uid: S
       .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, 
getSeed, Some(instr))
       .map(_.asInstanceOf[DecisionTreeRegressionModel])
 
-    val numFeatures = oldDataset.first().features.size
+    val numFeatures = trees.head.numFeatures
     instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures)
     new RandomForestRegressionModel(uid, trees, numFeatures)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to