Repository: spark
Updated Branches:
  refs/heads/master 5dfc01976 -> 69bc2c17f


[SPARK-13952][ML] Add random seed to GBT

## What changes were proposed in this pull request?

`GBTClassifier` and `GBTRegressor` should use a random seed for reproducible 
results. Because of the nature of the current unit tests, which compare GBTs in 
ML and GBTs in MLlib for equality, I also added a random seed to the MLlib GBT 
algorithm. I made alternate constructors in `mllib.tree.GradientBoostedTrees` 
to accept a random seed, but left them as `private[spark]` so as to not change 
the public API unnecessarily.

## How was this patch tested?

Existing unit tests verify that the functionality did not change. Other ML 
algorithms do not seem to have unit tests that directly test the functionality 
of random seeding, but reproducibility with seeding for GBTs is effectively 
verified by the existing tests that compare the ML and MLlib GBTs for equality. 
I can add more tests if needed.

Author: sethah <[email protected]>

Closes #11903 from sethah/SPARK-13952.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/69bc2c17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/69bc2c17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/69bc2c17

Branch: refs/heads/master
Commit: 69bc2c17f1ca047d4915a4791b624d60c5943dc8
Parents: 5dfc019
Author: sethah <[email protected]>
Authored: Wed Mar 23 15:08:47 2016 -0700
Committer: Joseph K. Bradley <[email protected]>
Committed: Wed Mar 23 15:08:47 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala |  8 ++----
 .../ml/regression/DecisionTreeRegressor.scala   |  2 +-
 .../spark/ml/regression/GBTRegressor.scala      |  8 ++----
 .../ml/tree/impl/GradientBoostedTrees.scala     | 30 ++++++++++++--------
 .../apache/spark/mllib/tree/DecisionTree.scala  | 15 ++++++++--
 .../spark/mllib/tree/GradientBoostedTrees.scala | 30 ++++++++++++++------
 .../ml/classification/GBTClassifierSuite.scala  |  4 ++-
 .../spark/ml/regression/GBTRegressorSuite.scala |  4 ++-
 .../mllib/tree/GradientBoostedTreesSuite.scala  |  4 +--
 9 files changed, 66 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 5a8845f..c31df3a 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -96,10 +96,7 @@ final class GBTClassifier @Since("1.4.0") (
   override def setSubsamplingRate(value: Double): this.type = 
super.setSubsamplingRate(value)
 
   @Since("1.4.0")
-  override def setSeed(value: Long): this.type = {
-    logWarning("The 'seed' parameter is currently ignored by Gradient 
Boosting.")
-    super.setSeed(value)
-  }
+  override def setSeed(value: Long): this.type = super.setSeed(value)
 
   // Parameters from GBTParams:
 
@@ -158,7 +155,8 @@ final class GBTClassifier @Since("1.4.0") (
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Classification)
-    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, 
boostingStrategy)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, 
boostingStrategy,
+      $(seed))
     new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 428bc7a..fa7cc43 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -97,7 +97,7 @@ final class DecisionTreeRegressor @Since("1.4.0") 
(@Since("1.4.0") override val
   private[ml] def train(data: RDD[LabeledPoint],
       oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
     val trees = RandomForest.run(data, oldStrategy, numTrees = 1, 
featureSubsetStrategy = "all",
-      seed = 0L, parentUID = Some(uid))
+      seed = $(seed), parentUID = Some(uid))
     trees.head.asInstanceOf[DecisionTreeRegressionModel]
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 091e1d5..da5b77e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -92,10 +92,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: Stri
   override def setSubsamplingRate(value: Double): this.type = 
super.setSubsamplingRate(value)
 
   @Since("1.4.0")
-  override def setSeed(value: Long): this.type = {
-    logWarning("The 'seed' parameter is currently ignored by Gradient 
Boosting.")
-    super.setSeed(value)
-  }
+  override def setSeed(value: Long): this.type = super.setSeed(value)
 
   // Parameters from GBTParams:
   @Since("1.4.0")
@@ -145,7 +142,8 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") 
override val uid: Stri
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
     val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Regression)
-    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, 
boostingStrategy)
+    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, 
boostingStrategy,
+      $(seed))
     new GBTRegressionModel(uid, baseLearners, learnerWeights, numFeatures)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index b9acc66..1c8a9b4 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -34,20 +34,23 @@ private[ml] object GradientBoostedTrees extends Logging {
   /**
    * Method to train a gradient boosting model
    * @param input Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @param seed Random seed.
    * @return tuple of ensemble models and weights:
    *         (array of decision tree models, array of model weights)
    */
-  def run(input: RDD[LabeledPoint],
-      boostingStrategy: OldBoostingStrategy
-      ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+  def run(
+      input: RDD[LabeledPoint],
+      boostingStrategy: OldBoostingStrategy,
+      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = 
false)
+        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = 
false, seed)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as 
regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, 
x.features))
-        GradientBoostedTrees.boost(remappedInput, remappedInput, 
boostingStrategy, validate = false)
+        GradientBoostedTrees.boost(remappedInput, remappedInput, 
boostingStrategy, validate = false,
+          seed)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by 
gradient boosting.")
     }
@@ -61,18 +64,19 @@ private[ml] object GradientBoostedTrees extends Logging {
    *                        but it should follow the same distribution.
    *                        E.g., these two datasets could be created from an 
original dataset
    *                        by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+   * @param seed Random seed.
    * @return tuple of ensemble models and weights:
    *         (array of decision tree models, array of model weights)
    */
   def runWithValidation(
       input: RDD[LabeledPoint],
       validationInput: RDD[LabeledPoint],
-      boostingStrategy: OldBoostingStrategy
-      ): (Array[DecisionTreeRegressionModel], Array[Double]) = {
+      boostingStrategy: OldBoostingStrategy,
+      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case OldAlgo.Regression =>
-        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, 
validate = true)
+        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, 
validate = true, seed)
       case OldAlgo.Classification =>
         // Map labels to -1, +1 so binary classification can be treated as 
regression.
         val remappedInput = input.map(
@@ -80,7 +84,7 @@ private[ml] object GradientBoostedTrees extends Logging {
         val remappedValidationInput = validationInput.map(
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedValidationInput, 
boostingStrategy,
-          validate = true)
+          validate = true, seed)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the 
gradient boosting.")
     }
@@ -142,6 +146,7 @@ private[ml] object GradientBoostedTrees extends Logging {
    * @param validationInput validation dataset, ignored if validate is set to 
false.
    * @param boostingStrategy boosting parameters
    * @param validate whether or not to use the validation dataset.
+   * @param seed Random seed.
    * @return tuple of ensemble models and weights:
    *         (array of decision tree models, array of model weights)
    */
@@ -149,7 +154,8 @@ private[ml] object GradientBoostedTrees extends Logging {
       input: RDD[LabeledPoint],
       validationInput: RDD[LabeledPoint],
       boostingStrategy: OldBoostingStrategy,
-      validate: Boolean): (Array[DecisionTreeRegressionModel], Array[Double]) 
= {
+      validate: Boolean,
+      seed: Long): (Array[DecisionTreeRegressionModel], Array[Double]) = {
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -191,7 +197,7 @@ private[ml] object GradientBoostedTrees extends Logging {
 
     // Initialize tree
     timer.start("building tree 0")
-    val firstTree = new DecisionTreeRegressor()
+    val firstTree = new DecisionTreeRegressor().setSeed(seed)
     val firstTreeModel = firstTree.train(input, treeStrategy)
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
@@ -223,7 +229,7 @@ private[ml] object GradientBoostedTrees extends Logging {
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
-      val dt = new DecisionTreeRegressor()
+      val dt = new DecisionTreeRegressor().setSeed(seed + m)
       val model = dt.train(data, treeStrategy)
       timer.stop(s"building tree $m")
       // Update partial model

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 8f02e09..c40d5e3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -43,11 +43,20 @@ import org.apache.spark.util.random.XORShiftRandom
  * @param strategy The configuration parameters for the tree algorithm which 
specify the type
  *                 of decision tree (classification or regression), feature 
type (continuous,
  *                 categorical), depth of the tree, quantile calculation 
strategy, etc.
+ * @param seed Random seed.
  */
 @Since("1.0.0")
-class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
+class DecisionTree private[spark] (private val strategy: Strategy, private val 
seed: Int)
   extends Serializable with Logging {
 
+  /**
+   * @param strategy The configuration parameters for the tree algorithm which 
specify the type
+   *                 of decision tree (classification or regression), feature 
type (continuous,
+   *                 categorical), depth of the tree, quantile calculation 
strategy, etc.
+   */
+  @Since("1.0.0")
+  def this(strategy: Strategy) = this(strategy, seed = 0)
+
   strategy.assertValid()
 
   /**
@@ -58,8 +67,8 @@ class DecisionTree @Since("1.0.0") (private val strategy: 
Strategy)
    */
   @Since("1.2.0")
   def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
-    // Note: random seed will not be used since numTrees = 1.
-    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = 
"all", seed = 0)
+    val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = 
"all",
+      seed = seed)
     val rfModel = rf.run(input)
     rfModel.trees(0)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index eb40fb0..d166dc7 100644
--- 
a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -47,12 +47,21 @@ import org.apache.spark.storage.StorageLevel
  *       for other loss functions.
  *
  * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ * @param seed Random seed.
  */
 @Since("1.2.0")
-class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: 
BoostingStrategy)
+class GradientBoostedTrees private[spark] (
+    private val boostingStrategy: BoostingStrategy,
+    private val seed: Int)
   extends Serializable with Logging {
 
   /**
+   * @param boostingStrategy Parameters for the gradient boosting algorithm.
+   */
+  @Since("1.2.0")
+  def this(boostingStrategy: BoostingStrategy) = this(boostingStrategy, seed = 
0)
+
+  /**
    * Method to train a gradient boosting model
    *
    * @param input Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]].
@@ -63,11 +72,12 @@ class GradientBoostedTrees @Since("1.2.0") (private val 
boostingStrategy: Boosti
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case Regression =>
-        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = 
false)
+        GradientBoostedTrees.boost(input, input, boostingStrategy, validate = 
false, seed)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as 
regression.
         val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, 
x.features))
-        GradientBoostedTrees.boost(remappedInput, remappedInput, 
boostingStrategy, validate = false)
+        GradientBoostedTrees.boost(remappedInput, remappedInput, 
boostingStrategy, validate = false,
+          seed)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the 
gradient boosting.")
     }
@@ -99,7 +109,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val 
boostingStrategy: Boosti
     val algo = boostingStrategy.treeStrategy.algo
     algo match {
       case Regression =>
-        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, 
validate = true)
+        GradientBoostedTrees.boost(input, validationInput, boostingStrategy, 
validate = true, seed)
       case Classification =>
         // Map labels to -1, +1 so binary classification can be treated as 
regression.
         val remappedInput = input.map(
@@ -107,7 +117,7 @@ class GradientBoostedTrees @Since("1.2.0") (private val 
boostingStrategy: Boosti
         val remappedValidationInput = validationInput.map(
           x => new LabeledPoint((x.label * 2) - 1, x.features))
         GradientBoostedTrees.boost(remappedInput, remappedValidationInput, 
boostingStrategy,
-          validate = true)
+          validate = true, seed)
       case _ =>
         throw new IllegalArgumentException(s"$algo is not supported by the 
gradient boosting.")
     }
@@ -140,7 +150,7 @@ object GradientBoostedTrees extends Logging {
   def train(
       input: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
-    new GradientBoostedTrees(boostingStrategy).run(input)
+    new GradientBoostedTrees(boostingStrategy, seed = 0).run(input)
   }
 
   /**
@@ -159,13 +169,15 @@ object GradientBoostedTrees extends Logging {
    * @param validationInput Validation dataset, ignored if validate is set to 
false.
    * @param boostingStrategy Boosting parameters.
    * @param validate Whether or not to use the validation dataset.
+   * @param seed Random seed.
    * @return GradientBoostedTreesModel that can be used for prediction.
    */
   private def boost(
       input: RDD[LabeledPoint],
       validationInput: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy,
-      validate: Boolean): GradientBoostedTreesModel = {
+      validate: Boolean,
+      seed: Int): GradientBoostedTreesModel = {
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -207,7 +219,7 @@ object GradientBoostedTrees extends Logging {
 
     // Initialize tree
     timer.start("building tree 0")
-    val firstTreeModel = new DecisionTree(treeStrategy).run(input)
+    val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input)
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
@@ -238,7 +250,7 @@ object GradientBoostedTrees extends Logging {
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
-      val model = new DecisionTree(treeStrategy).run(data)
+      val model = new DecisionTree(treeStrategy, seed + m).run(data)
       timer.stop(s"building tree $m")
       // Update partial model
       baseLearners(m) = model

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 29efd67..f3680ed 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -74,6 +74,7 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext {
           .setLossType("logistic")
           .setMaxIter(maxIter)
           .setStepSize(learningRate)
+          .setSeed(123)
         compareAPIs(data, None, gbt, categoricalFeatures)
     }
   }
@@ -91,6 +92,7 @@ class GBTClassifierSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       .setMaxIter(5)
       .setStepSize(0.1)
       .setCheckpointInterval(2)
+      .setSeed(123)
     val model = gbt.fit(df)
 
     // copied model must have the same parent.
@@ -159,7 +161,7 @@ private object GBTClassifierSuite extends SparkFunSuite {
     val numFeatures = data.first().features.size
     val oldBoostingStrategy =
       gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
-    val oldGBT = new OldGBT(oldBoostingStrategy)
+    val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
     val oldModel = oldGBT.run(data)
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 2)
     val newModel = gbt.fit(newData)

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index db68606..84148a8 100644
--- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -65,6 +65,7 @@ class GBTRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContext {
             .setLossType(loss)
             .setMaxIter(maxIter)
             .setStepSize(learningRate)
+            .setSeed(123)
           compareAPIs(data, None, gbt, categoricalFeatures)
       }
     }
@@ -104,6 +105,7 @@ class GBTRegressorSuite extends SparkFunSuite with 
MLlibTestSparkContext {
       .setMaxIter(5)
       .setStepSize(0.1)
       .setCheckpointInterval(2)
+      .setSeed(123)
     val model = gbt.fit(df)
 
     sc.checkpointDir = None
@@ -169,7 +171,7 @@ private object GBTRegressorSuite extends SparkFunSuite {
       categoricalFeatures: Map[Int, Int]): Unit = {
     val numFeatures = data.first().features.size
     val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, 
OldAlgo.Regression)
-    val oldGBT = new OldGBT(oldBoostingStrategy)
+    val oldGBT = new OldGBT(oldBoostingStrategy, gbt.getSeed.toInt)
     val oldModel = oldGBT.run(data)
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 
numClasses = 0)
     val newModel = gbt.fit(newData)

http://git-wip-us.apache.org/repos/asf/spark/blob/69bc2c17/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 58828b3..747c267 100644
--- 
a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -171,13 +171,13 @@ class GradientBoostedTreesSuite extends SparkFunSuite 
with MLlibTestSparkContext
         categoricalFeaturesInfo = Map.empty)
       val boostingStrategy =
         new BoostingStrategy(treeStrategy, loss, numIterations, validationTol 
= 0.0)
-      val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+      val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0)
         .runWithValidation(trainRdd, validateRdd)
       val numTrees = gbtValidate.numTrees
       assert(numTrees !== numIterations)
 
       // Test that it performs better on the validation dataset.
-      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+      val gbt = new GradientBoostedTrees(boostingStrategy, seed = 
0).run(trainRdd)
       val (errorWithoutValidation, errorWithValidation) = {
         if (algo == Classification) {
           val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label 
- 1, x.features))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to