Repository: spark Updated Branches: refs/heads/branch-1.4 80b0fe200 -> bc355e243
[SPARK-8736] [ML] GBTRegressor should not threshold prediction Changed GBTRegressor so it does NOT threshold the prediction. Added test which fails with bug but works after fix. CC: feynmanliang mengxr Author: Joseph K. Bradley <jos...@databricks.com> Closes #7134 from jkbradley/gbrt-fix and squashes the following commits: 613b90e [Joseph K. Bradley] Changed GBTRegressor so it does NOT threshold the prediction (cherry picked from commit 3ba23ffd377d12383d923d1550ac8e2b916090fc) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc355e24 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc355e24 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc355e24 Branch: refs/heads/branch-1.4 Commit: bc355e24368123baca5335ddf5560ded1da11141 Parents: 80b0fe2 Author: Joseph K. Bradley <jos...@databricks.com> Authored: Tue Jun 30 14:02:50 2015 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Tue Jun 30 14:02:57 2015 -0700 ---------------------------------------------------------------------- .../spark/ml/regression/GBTRegressor.scala | 3 +-- .../spark/ml/regression/GBTRegressorSuite.scala | 23 +++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/bc355e24/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 036e3ac..47c110d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -172,8 +172,7 @@ final class GBTRegressionModel( // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } override def copy(extra: ParamMap): GBTRegressionModel = { http://git-wip-us.apache.org/repos/asf/spark/blob/bc355e24/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 98fb3d3..9682edc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** @@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org