Repository: spark Updated Branches: refs/heads/master 6bb60b30f -> 1bb60ab83
[SPARK-26153][ML] GBT & RandomForest avoid unnecessary `first` job to compute `numFeatures` ## What changes were proposed in this pull request? Use the base models' `numFeatures` instead of running a `first` job. ## How was this patch tested? Existing tests. Closes #23123 from zhengruifeng/avoid_first_job. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1bb60ab8 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1bb60ab8 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1bb60ab8 Branch: refs/heads/master Commit: 1bb60ab8392adf8b896cc04fb1d060620cf09d8a Parents: 6bb60b3 Author: zhengruifeng <[email protected]> Authored: Mon Nov 26 05:57:33 2018 -0600 Committer: Sean Owen <[email protected]> Committed: Mon Nov 26 05:57:33 2018 -0600 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/GBTClassifier.scala | 5 +++-- .../spark/ml/classification/RandomForestClassifier.scala | 2 +- .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 6 ++++-- .../org/apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index fab8155..09a9df6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -180,7 +180,6 @@ class GBTClassifier @Since("1.4.0") ( 
(convert2LabeledPoint(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification) val numClasses = 2 @@ -196,7 +195,6 @@ class GBTClassifier @Since("1.4.0") ( maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy, validationIndicatorCol) - instr.logNumFeatures(numFeatures) instr.logNumClasses(numClasses) val (baseLearners, learnerWeights) = if (withValidation) { @@ -206,6 +204,9 @@ class GBTClassifier @Since("1.4.0") ( GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures) } http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 05fff88..0a3bfd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -142,7 +142,7 @@ class RandomForestClassifier @Since("1.4.0") ( .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNumClasses(numClasses) instr.logNumFeatures(numFeatures) new RandomForestClassificationModel(uid, trees, numFeatures, numClasses) 
http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 186fa23..9b386ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -165,7 +165,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { (extractLabeledPoints(dataset), null) } - val numFeatures = trainDataset.first().features.size val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression) instr.logPipelineStage(this) @@ -173,7 +172,6 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, lossType, maxDepth, maxBins, maxIter, maxMemoryInMB, minInfoGain, minInstancesPerNode, seed, stepSize, subsamplingRate, cacheNodeIds, checkpointInterval, featureSubsetStrategy) - instr.logNumFeatures(numFeatures) val (baseLearners, learnerWeights) = if (withValidation) { GradientBoostedTrees.runWithValidation(trainDataset, validationDataset, boostingStrategy, @@ -182,6 +180,10 @@ class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) GradientBoostedTrees.run(trainDataset, boostingStrategy, $(seed), $(featureSubsetStrategy)) } + + val numFeatures = baseLearners.head.numFeatures + instr.logNumFeatures(numFeatures) + new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures) } http://git-wip-us.apache.org/repos/asf/spark/blob/1bb60ab8/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 7f5e668..afa9a64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -133,7 +133,7 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr)) .map(_.asInstanceOf[DecisionTreeRegressionModel]) - val numFeatures = oldDataset.first().features.size + val numFeatures = trees.head.numFeatures instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) new RandomForestRegressionModel(uid, trees, numFeatures) } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
