Repository: spark
Updated Branches:
  refs/heads/master 235865754 -> cfff397f0


[SPARK-6004][MLlib] Pick the best model when training GradientBoostedTrees with validation

Since the validation error does not change monotonically, in practice, it should be proper to pick the best model when training GradientBoostedTrees with validation instead of stopping it early.

Author: Liang-Chi Hsieh <[email protected]>

Closes #4763 from viirya/gbt_record_model and squashes the following commits:

452e049 [Liang-Chi Hsieh] Address comment.
ea2fae2 [Liang-Chi Hsieh] Pick the best model when training GradientBoostedTrees with validation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cfff397f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cfff397f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cfff397f

Branch: refs/heads/master
Commit: cfff397f0adb27ca102cca43a7696e9fb1819ee0
Parents: 2358657
Author: Liang-Chi Hsieh <[email protected]>
Authored: Thu Feb 26 10:51:47 2015 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Thu Feb 26 10:51:47 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/GradientBoostedTrees.scala | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cfff397f/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index b4466ff..a9c93e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -251,9 +251,15 @@ object GradientBoostedTrees extends Logging {
 
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
-
-    new GradientBoostedTreesModel(
-      boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    if (validate) {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo,
+        baseLearners.slice(0, bestM),
+        baseLearnerWeights.slice(0, bestM))
+    } else {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    }
   }
 
 }
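
For readers who want to see the idea in isolation, the following is a minimal standalone sketch, not the Spark implementation. It assumes the validation error is recorded after each boosting iteration and that bestM is the number of leading trees whose ensemble gave the lowest error. The hunk above does not show how bestM is computed, so that part is an assumption. Learner, bestPrefixLength and truncate are hypothetical names introduced only for illustration.

```scala
object BestModelSketch {
  // Hypothetical stand-in for a trained base learner; this is not a Spark class.
  final case class Learner(id: Int)

  /**
   * Returns the ensemble size with the lowest validation error.
   * Assumption: validationErrors(i) is the error of the ensemble made of
   * the first i + 1 learners.
   */
  def bestPrefixLength(validationErrors: Seq[Double]): Int = {
    require(validationErrors.nonEmpty, "need at least one validation error")
    var bestM = 1
    var bestError = validationErrors.head
    validationErrors.zipWithIndex.foreach { case (err, i) =>
      // Strict comparison, so ties keep the smaller ensemble.
      if (err < bestError) {
        bestError = err
        bestM = i + 1
      }
    }
    bestM
  }

  /** Keeps only the first bestM learners and weights, mirroring slice(0, bestM) in the diff. */
  def truncate(
      learners: Array[Learner],
      weights: Array[Double],
      validationErrors: Seq[Double]): (Array[Learner], Array[Double]) = {
    val bestM = bestPrefixLength(validationErrors)
    (learners.slice(0, bestM), weights.slice(0, bestM))
  }

  def main(args: Array[String]): Unit = {
    val learners = Array.tabulate(5)(Learner(_))
    val weights = Array(1.0, 0.1, 0.1, 0.1, 0.1)
    // Made-up errors after 1 to 5 trees: the error rises at tree 3, then
    // reaches its minimum at tree 4, so stopping at the first rise would miss it.
    val errors = Seq(0.40, 0.31, 0.33, 0.29, 0.35)
    val (bestLearners, _) = truncate(learners, weights, errors)
    println(s"kept ${bestLearners.length} of ${learners.length} trees")
  }
}
```

With the made-up errors above, stopping at the first increase would keep two trees, while selecting the overall minimum keeps four. This is the non-monotonic case the commit message describes.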
