Repository: spark
Updated Branches:
  refs/heads/master 235865754 -> cfff397f0


[SPARK-6004][MLlib] Pick the best model when training GradientBoostedTrees with validation

Since the validation error does not change monotonically in practice, it is
better to pick the best model when training GradientBoostedTrees with
validation than to stop training early.
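The motivation can be illustrated with a small sketch. This is not Spark code: the per-iteration validation errors are made-up numbers, the helper names are hypothetical, and the early-stopping rule (stop at the first iteration that fails to improve the best error by at least `tol`) is an assumed simplification. It shows that on a non-monotonic error curve, stopping at the first bump keeps fewer trees than keeping the best prefix seen over all iterations.

```scala
// Hypothetical illustration, not Spark's implementation.
// `errors(i)` is an assumed validation error after boosting iteration i + 1.
object BestIterationSketch {

  /** Trees kept when training stops at the first iteration whose validation
    * error does not improve on the best error so far by at least `tol`
    * (an assumed early-stopping rule). */
  def earlyStopCount(errors: Seq[Double], tol: Double): Int = {
    var best = Double.PositiveInfinity
    var bestM = 0
    var i = 0
    var stopped = false
    while (i < errors.length && !stopped) {
      if (best - errors(i) < tol) {
        stopped = true // first non-improving iteration: stop here
      } else {
        best = errors(i)
        bestM = i + 1
      }
      i += 1
    }
    bestM
  }

  /** Trees kept when every iteration runs and the prefix with the lowest
    * validation error is returned (ties go to the fewest trees, because
    * minBy returns the first minimal element). */
  def bestCount(errors: Seq[Double]): Int =
    if (errors.isEmpty) 0
    else errors.zipWithIndex.minBy(_._1)._2 + 1
}
```

With the made-up curve `Seq(0.50, 0.40, 0.42, 0.30, 0.35)` and `tol = 0.0`, the early-stopping rule keeps 2 trees because of the bump at iteration 3, while picking the best iteration keeps 4 trees, whose error (0.30) is lower.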

Author: Liang-Chi Hsieh <[email protected]>

Closes #4763 from viirya/gbt_record_model and squashes the following commits:

452e049 [Liang-Chi Hsieh] Address comment.
ea2fae2 [Liang-Chi Hsieh] Pick the best model when training GradientBoostedTrees with validation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cfff397f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cfff397f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cfff397f

Branch: refs/heads/master
Commit: cfff397f0adb27ca102cca43a7696e9fb1819ee0
Parents: 2358657
Author: Liang-Chi Hsieh <[email protected]>
Authored: Thu Feb 26 10:51:47 2015 -0800
Committer: Joseph K. Bradley <[email protected]>
Committed: Thu Feb 26 10:51:47 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/mllib/tree/GradientBoostedTrees.scala  | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cfff397f/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index b4466ff..a9c93e1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -251,9 +251,15 @@ object GradientBoostedTrees extends Logging {
 
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
-
-    new GradientBoostedTreesModel(
-      boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    if (validate) {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo,
+        baseLearners.slice(0, bestM),
+        baseLearnerWeights.slice(0, bestM))
+    } else {
+      new GradientBoostedTreesModel(
+        boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+    }
   }
 
 }

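The effect of the new `if (validate)` branch can be sketched outside Spark. In this simplified version, `SimpleEnsemble` and `BuildModelSketch.buildModel` are hypothetical stand-ins for `GradientBoostedTreesModel` and the end of the boosting loop, each base learner is a plain `Double => Double` function, and `bestM` is assumed to hold the number of iterations with the lowest validation error, as the commit message implies (the hunk does not show where it is computed). The point is that `slice(0, bestM)` keeps the same prefix of learners and weights, so the two arrays stay aligned.

```scala
// Hypothetical sketch mirroring the if/else in the diff; not Spark's API.
final case class SimpleEnsemble(
    learners: Array[Double => Double],
    weights: Array[Double]) {
  require(learners.length == weights.length, "one weight per learner")

  // Weighted sum of base-learner outputs (regression-style prediction).
  def predict(x: Double): Double =
    learners.zip(weights).map { case (f, w) => w * f(x) }.sum

  def numTrees: Int = learners.length
}

object BuildModelSketch {
  def buildModel(
      validate: Boolean,
      bestM: Int,
      learners: Array[Double => Double],
      weights: Array[Double]): SimpleEnsemble =
    if (validate) {
      // Keep only the first bestM learners and their matching weights.
      SimpleEnsemble(learners.slice(0, bestM), weights.slice(0, bestM))
    } else {
      // Without validation, every trained learner is kept.
      SimpleEnsemble(learners, weights)
    }
}
```

For example, with learners `x => x`, `x => 2 * x`, `x => 3 * x`, weights `1.0, 0.5, 0.5` and `bestM = 2`, the unvalidated model predicts 7.0 at `x = 2.0` using 3 learners, while the validated model predicts 4.0 using only the first 2.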
