[SPARK-4486][MLLIB] Improve GradientBoosting APIs and doc

There are some inconsistencies in the gradient boosting APIs. The target is a general boosting meta-algorithm, but the implementation is attached to trees. This was partially due to the delay of SPARK-1856. For the 1.2 release, we should make the APIs consistent:
1. WeightedEnsembleModel -> private[tree] TreeEnsembleModel, and renamed members accordingly.
2. GradientBoosting -> GradientBoostedTrees.
3. Add RandomForestModel and GradientBoostedTreesModel and hide CombiningStrategy.
4. Slightly refactor TreeEnsembleModel (voting takes weights into consideration).
5. Remove `trainClassifier` and `trainRegressor` from `GradientBoostedTrees` because they are the same as `train`.
6. Rename the class `train` method to `run` because it hides the static methods with the same name in Java. Deprecate the `DecisionTree.train` class method.
7. Simplify BoostingStrategy and make sure the input strategy is not modified. Users should put algo and numClasses in treeStrategy. We create ensembleStrategy inside boosting.
8. Fix a bug in GradientBoostedTreesSuite with AbsoluteError.
9. Doc updates.

manishamde jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #3374 from mengxr/SPARK-4486 and squashes the following commits:

7097251 [Xiangrui Meng] address joseph's comments
98dea09 [Xiangrui Meng] address manish's comments
4aae3b7 [Xiangrui Meng] add RandomForestModel and GradientBoostedTreesModel, hide CombiningStrategy
ea4c467 [Xiangrui Meng] fix unit tests
751da4e [Xiangrui Meng] rename class method train -> run
19030a5 [Xiangrui Meng] update boosting public APIs

(cherry picked from commit 15cacc81240eed8834b4730c5c6dc3238f003465)
Signed-off-by: Xiangrui Meng <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e958132a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e958132a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e958132a

Branch: refs/heads/branch-1.2
Commit: e958132a80d202b70976632a51c7e8e4b58d9c4e
Parents: 83d24ef
Author: Xiangrui Meng <[email protected]>
Authored: Thu Nov 20 00:48:59 2014 -0800
Committer: Xiangrui Meng <[email protected]>
Committed: Thu Nov 20 00:49:11 2014 -0800
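[Editor's note: the renames above can be illustrated with a small before/after sketch in Scala. This snippet is not part of the patch; it is a hypothetical usage example assuming the Spark 1.2 MLlib API described in this commit, and it requires a live SparkContext to actually run.]

```scala
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.rdd.RDD

object BoostingMigrationSketch {
  // Before this patch (old API):
  //   val model: WeightedEnsembleModel =
  //     GradientBoosting.trainClassifier(data, boostingStrategy)
  //
  // After this patch there is a single entry point, GradientBoostedTrees.train,
  // and tree-level parameters (algo, numClasses, maxDepth, ...) live in
  // boostingStrategy.treeStrategy rather than on the boosting strategy itself.
  def train(data: RDD[LabeledPoint]): GradientBoostedTreesModel = {
    val boostingStrategy = BoostingStrategy.defaultParams("Classification")
    boostingStrategy.numIterations = 10
    boostingStrategy.treeStrategy.maxDepth = 5
    boostingStrategy.treeStrategy.numClassesForClassification = 2
    GradientBoostedTrees.train(data, boostingStrategy)
  }
}
```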
----------------------------------------------------------------------
 .../mllib/JavaGradientBoostedTrees.java         | 126 ----------
 .../mllib/JavaGradientBoostedTreesRunner.java   | 126 ++++++++++
 .../examples/mllib/DecisionTreeRunner.scala     |  18 +-
 .../examples/mllib/GradientBoostedTrees.scala   | 146 -----------
 .../mllib/GradientBoostedTreesRunner.scala      | 146 +++++++++++
 .../apache/spark/mllib/tree/DecisionTree.scala  |  20 +-
 .../spark/mllib/tree/GradientBoostedTrees.scala | 192 ++++++++++++++
 .../spark/mllib/tree/GradientBoosting.scala     | 249 -------------------
 .../apache/spark/mllib/tree/RandomForest.scala  |  40 ++-
 .../tree/configuration/BoostingStrategy.scala   |  50 ++--
 .../EnsembleCombiningStrategy.scala             |   8 +-
 .../mllib/tree/configuration/Strategy.scala     |   7 +
 .../spark/mllib/tree/loss/AbsoluteError.scala   |   6 +-
 .../apache/spark/mllib/tree/loss/LogLoss.scala  |   6 +-
 .../org/apache/spark/mllib/tree/loss/Loss.scala |   6 +-
 .../spark/mllib/tree/loss/SquaredError.scala    |   6 +-
 .../mllib/tree/model/DecisionTreeModel.scala    |   4 +-
 .../tree/model/WeightedEnsembleModel.scala      | 158 ------------
 .../mllib/tree/model/treeEnsembleModels.scala   | 178 +++++++++++++
 .../spark/mllib/tree/JavaDecisionTreeSuite.java |   2 +-
 .../spark/mllib/tree/EnsembleTestHelper.scala   |  30 ++-
 .../mllib/tree/GradientBoostedTreesSuite.scala  | 117 +++++++++
 .../mllib/tree/GradientBoostingSuite.scala      | 126 ----------
 .../spark/mllib/tree/RandomForestSuite.scala    |  14 +-
 24 files changed, 863 insertions(+), 918 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java
deleted file mode 100644
index 1af2067..0000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTrees.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.GradientBoosting; -import org.apache.spark.mllib.tree.configuration.BoostingStrategy; -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel; -import org.apache.spark.mllib.util.MLUtils; - -/** - * Classification and regression using gradient-boosted decision trees. 
- */ -public final class JavaGradientBoostedTrees { - - private static void usage() { - System.err.println("Usage: JavaGradientBoostedTrees <libsvm format data file>" + - " <Classification/Regression>"); - System.exit(-1); - } - - public static void main(String[] args) { - String datapath = "data/mllib/sample_libsvm_data.txt"; - String algo = "Classification"; - if (args.length >= 1) { - datapath = args[0]; - } - if (args.length >= 2) { - algo = args[1]; - } - if (args.length > 2) { - usage(); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTrees"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); - - // Set parameters. - // Note: All features are treated as continuous. - BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); - boostingStrategy.setNumIterations(10); - boostingStrategy.weakLearnerParams().setMaxDepth(5); - - if (algo.equals("Classification")) { - // Compute the number of classes from the data. - Integer numClasses = data.map(new Function<LabeledPoint, Double>() { - @Override public Double call(LabeledPoint p) { - return p.label(); - } - }).countByValue().size(); - boostingStrategy.setNumClassesForClassification(numClasses); // ignored for Regression - - // Train a GradientBoosting model for classification. 
- final WeightedEnsembleModel model = GradientBoosting.trainClassifier(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD<Double, Double> predictionAndLabel = - data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { - @Override public Tuple2<Double, Double> call(LabeledPoint p) { - return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); - } - }); - Double trainErr = - 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { - @Override public Boolean call(Tuple2<Double, Double> pl) { - return !pl._1().equals(pl._2()); - } - }).count() / data.count(); - System.out.println("Training error: " + trainErr); - System.out.println("Learned classification tree model:\n" + model); - } else if (algo.equals("Regression")) { - // Train a GradientBoosting model for classification. - final WeightedEnsembleModel model = GradientBoosting.trainRegressor(data, boostingStrategy); - - // Evaluate model on training instances and compute training error - JavaPairRDD<Double, Double> predictionAndLabel = - data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { - @Override public Tuple2<Double, Double> call(LabeledPoint p) { - return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); - } - }); - Double trainMSE = - predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { - @Override public Double call(Tuple2<Double, Double> pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2<Double, Double, Double>() { - @Override public Double call(Double a, Double b) { - return a + b; - } - }) / data.count(); - System.out.println("Training Mean Squared Error: " + trainMSE); - System.out.println("Learned regression tree model:\n" + model); - } else { - usage(); - } - - sc.stop(); - } -} 
http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java new file mode 100644 index 0000000..4a5ac40 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTreesRunner { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file>" + + " <Classification/Regression>"); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.treeStrategy().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. 
+ Integer numClasses = data.map(new Function<LabeledPoint, Double>() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); + + // Train a GradientBoosting model for classification. + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD<Double, Double> predictionAndLabel = + data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { + @Override public Tuple2<Double, Double> call(LabeledPoint p) { + return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { + @Override public Boolean call(Tuple2<Double, Double> pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. 
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD<Double, Double> predictionAndLabel = + data.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { + @Override public Tuple2<Double, Double> call(LabeledPoint p) { + return new Tuple2<Double, Double>(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() { + @Override public Double call(Tuple2<Double, Double> pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2<Double, Double, Double>() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 63f02cf..98f9d16 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,11 +22,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} +import org.apache.spark.mllib.tree.{DecisionTree, 
RandomForest, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -352,21 +352,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } - - /** - * Calculates the mean squared error for regression. - */ private[mllib] def meanSquaredError( - tree: WeightedEnsembleModel, + model: { def predict(features: Vector): Double }, data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala deleted file mode 100644 index 9b6db01..0000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTrees.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib - -import scopt.OptionParser - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.mllib.tree.GradientBoosting -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} -import org.apache.spark.util.Utils - -/** - * An example runner for Gradient Boosting using decision trees as weak learners. Run with - * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.GradientBoostedTrees [options] - * }}} - * If you use it as a template to create your own app, please use `spark-submit` to submit your app. - * - * Note: This script treats all features as real-valued (not categorical). - * To include categorical features, modify categoricalFeaturesInfo. 
- */ -object GradientBoostedTrees { - - case class Params( - input: String = null, - testInput: String = "", - dataFormat: String = "libsvm", - algo: String = "Classification", - maxDepth: Int = 5, - numIterations: Int = 10, - fracTest: Double = 0.2) extends AbstractParams[Params] - - def main(args: Array[String]) { - val defaultParams = Params() - - val parser = new OptionParser[Params]("GradientBoostedTrees") { - head("GradientBoostedTrees: an example decision tree app.") - opt[String]("algo") - .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") - .action((x, c) => c.copy(algo = x)) - opt[Int]("maxDepth") - .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") - .action((x, c) => c.copy(maxDepth = x)) - opt[Int]("numIterations") - .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") - .action((x, c) => c.copy(numIterations = x)) - opt[Double]("fracTest") - .text(s"fraction of data to hold out for testing. If given option testInput, " + - s"this option is ignored. default: ${defaultParams.fracTest}") - .action((x, c) => c.copy(fracTest = x)) - opt[String]("testInput") - .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ - s" default: ${defaultParams.testInput}") - .action((x, c) => c.copy(testInput = x)) - opt[String]("<dataFormat>") - .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") - .action((x, c) => c.copy(dataFormat = x)) - arg[String]("<input>") - .text("input path to labeled examples") - .required() - .action((x, c) => c.copy(input = x)) - checkConfig { params => - if (params.fracTest < 0 || params.fracTest > 1) { - failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") - } else { - success - } - } - } - - parser.parse(args, defaultParams).map { params => - run(params) - }.getOrElse { - sys.exit(1) - } - } - - def run(params: Params) { - - val conf = new SparkConf().setAppName(s"GradientBoostedTrees with $params") - val sc = new SparkContext(conf) - - println(s"GradientBoostedTrees with parameters:\n$params") - - // Load training and test data and cache it. - val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, - params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) - - val boostingStrategy = BoostingStrategy.defaultParams(params.algo) - boostingStrategy.numClassesForClassification = numClasses - boostingStrategy.numIterations = params.numIterations - boostingStrategy.weakLearnerParams.maxDepth = params.maxDepth - - val randomSeed = Utils.random.nextInt() - if (params.algo == "Classification") { - val startTime = System.nanoTime() - val model = GradientBoosting.trainClassifier(training, boostingStrategy) - val elapsedTime = (System.nanoTime() - startTime) / 1e9 - println(s"Training time: $elapsedTime seconds") - if (model.totalNumNodes < 30) { - println(model.toDebugString) // Print full model. - } else { - println(model) // Print model summary. 
- } - val trainAccuracy = - new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) - .precision - println(s"Train accuracy = $trainAccuracy") - val testAccuracy = - new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision - println(s"Test accuracy = $testAccuracy") - } else if (params.algo == "Regression") { - val startTime = System.nanoTime() - val model = GradientBoosting.trainRegressor(training, boostingStrategy) - val elapsedTime = (System.nanoTime() - startTime) / 1e9 - println(s"Training time: $elapsedTime seconds") - if (model.totalNumNodes < 30) { - println(model.toDebugString) // Print full model. - } else { - println(model) // Print model summary. - } - val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) - println(s"Train mean squared error = $trainMSE") - val testMSE = DecisionTreeRunner.meanSquaredError(model, test) - println(s"Test mean squared error = $testMSE") - } - - sc.stop() - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala new file mode 100644 index 0000000..1def8b4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example mllib.GradientBoostedTreesRunner [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. 
+ */ +object GradientBoostedTreesRunner { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." 
+ + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("<dataFormat>") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("<input>") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTreesRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.treeStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.treeStrategy.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoostedTrees.train(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+      }
+      val trainAccuracy =
+        new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label)))
+          .precision
+      println(s"Train accuracy = $trainAccuracy")
+      val testAccuracy =
+        new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision
+      println(s"Test accuracy = $testAccuracy")
+    } else if (params.algo == "Regression") {
+      val startTime = System.nanoTime()
+      val model = GradientBoostedTrees.train(training, boostingStrategy)
+      val elapsedTime = (System.nanoTime() - startTime) / 1e9
+      println(s"Training time: $elapsedTime seconds")
+      if (model.totalNumNodes < 30) {
+        println(model.toDebugString) // Print full model.
+      } else {
+        println(model) // Print model summary.
+      }
+      val trainMSE = DecisionTreeRunner.meanSquaredError(model, training)
+      println(s"Train mean squared error = $trainMSE")
+      val testMSE = DecisionTreeRunner.meanSquaredError(model, test)
+      println(s"Test mean squared error = $testMSE")
+    }
+
+    sc.stop()
+  }
+}


http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 78acc17..3d91867 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -58,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @return DecisionTreeModel that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
+  def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
     // Note: random seed will not be used since numTrees = 1.
     val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
-    val rfModel = rf.train(input)
-    rfModel.weakHypotheses(0)
+    val rfModel = rf.run(input)
+    rfModel.trees(0)
   }
 
+  /**
+   * Trains a decision tree model over an RDD. This is deprecated because it hides the static
+   * methods with the same name in Java.
+   */
+  @deprecated("Please use DecisionTree.run instead.", "1.2.0")
+  def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input)
 }
 
 object DecisionTree extends Serializable with Logging {
@@ -86,7 +92,7 @@ object DecisionTree extends Serializable with Logging {
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -112,7 +118,7 @@ object DecisionTree extends Serializable with Logging {
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -140,7 +146,7 @@ object DecisionTree extends Serializable with Logging {
       maxDepth: Int,
       numClassesForClassification: Int): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**
@@ -177,7 +183,7 @@ object DecisionTree extends Serializable with Logging {
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
     val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification,
       maxBins, quantileCalculationStrategy, categoricalFeaturesInfo)
-    new DecisionTree(strategy).train(input)
+    new DecisionTree(strategy).run(input)
   }
 
   /**


http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
new file mode 100644
index 0000000..cb4ddfc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: Experimental ::
+ * A class that implements Stochastic Gradient Boosting for regression and binary classification.
+ *
+ * The implementation is based upon:
+ *   J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
+ *
+ * Notes:
+ *  - This currently can be run with several loss functions.  However, only SquaredError is
+ *    fully supported.  Specifically, the loss function should be used to compute the gradient
+ *    (to re-label training instances on each iteration) and to weight weak hypotheses.
+ *    Currently, gradients are computed correctly for the available loss functions,
+ *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
+ *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ *
+ * @param boostingStrategy Parameters for the gradient boosting algorithm.
+ */
+@Experimental
+class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
+  extends Serializable with Logging {
+
+  /**
+   * Method to train a gradient boosting model
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return a gradient boosted trees model that can be used for prediction
+   */
+  def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+    val algo = boostingStrategy.treeStrategy.algo
+    algo match {
+      case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+      case Classification =>
+        // Map labels to -1, +1 so binary classification can be treated as regression.
+        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+    }
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]].
+   */
+  def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+    run(input.rdd)
+  }
+}
+
+
+object GradientBoostedTrees extends Logging {
+
+  /**
+   * Method to train a gradient boosting model.
+ * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. + * @return a gradient boosted trees model that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + new GradientBoostedTrees(boostingStrategy).run(input) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] + */ + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + train(input.rdd, boostingStrategy) + } + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param boostingStrategy boosting parameters + * @return a gradient boosted trees model that can be used for prediction + */ + private def boost( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + boostingStrategy.assertValid() + + // Initialize gradient boosting parameters + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. 
+ val treeStrategy = boostingStrategy.treeStrategy.copy + treeStrategy.algo = Regression + treeStrategy.impurity = Variance + treeStrategy.assertValid() + + // Cache input + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + var data = input + + // Initialize tree + timer.start("building tree 0") + val firstTreeModel = new DecisionTree(treeStrategy).run(data) + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = 1.0 + val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) + logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + // pseudo-residual for second iteration + data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), + point.features)) + + var m = 1 + while (m < numIterations) { + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val model = new DecisionTree(treeStrategy).run(data) + timer.stop(s"building tree $m") + // Create partial model + baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal.
+ baseLearnerWeights(m) = learningRate + // Note: A model of type regression is used since we require raw prediction + val partialModel = new GradientBoostedTreesModel( + Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) + logDebug("error of gbt = " + loss.computeError(partialModel, input)) + // Update data with pseudo-residuals + data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), + point.features)) + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala deleted file mode 100644 index f729344..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala +++ /dev/null @@ -1,249 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree - -import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum -import org.apache.spark.mllib.tree.impl.TimeTracker -import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -/** - * :: Experimental :: - * A class that implements Stochastic Gradient Boosting - * for regression and binary classification problems. - * - * The implementation is based upon: - * J.H. Friedman. "Stochastic Gradient Boosting." 1999. - * - * Notes: - * - This currently can be run with several loss functions. However, only SquaredError is - * fully supported. Specifically, the loss function should be used to compute the gradient - * (to re-label training instances on each iteration) and to weight weak hypotheses. - * Currently, gradients are computed correctly for the available loss functions, - * but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError. - * Running with those losses will likely behave reasonably, but lacks the same guarantees. - * - * @param boostingStrategy Parameters for the gradient boosting algorithm - */ -@Experimental -class GradientBoosting ( - private val boostingStrategy: BoostingStrategy) extends Serializable with Logging { - - boostingStrategy.weakLearnerParams.algo = Regression - boostingStrategy.weakLearnerParams.impurity = impurity.Variance - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. 
- boostingStrategy.weakLearnerParams.numClassesForClassification = - boostingStrategy.numClassesForClassification - - boostingStrategy.assertValid() - - /** - * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return WeightedEnsembleModel that can be used for prediction - */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - algo match { - case Regression => GradientBoosting.boost(input, boostingStrategy) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoosting.boost(remappedInput, boostingStrategy) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } - } - -} - - -object GradientBoosting extends Logging { - - /** - * Method to train a gradient boosting model. - * - * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - * is recommended to clearly specify regression. - * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - * is recommended to clearly specify regression. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def train( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting classification model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
- * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainClassifier( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Method to train a gradient boosting regression model. - * - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * For classification, labels should take values {0, 1, ..., numClasses-1}. - * For regression, labels are real numbers. - * @param boostingStrategy Configuration options for the boosting algorithm. - * @return WeightedEnsembleModel that can be used for prediction - */ - def trainRegressor( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - val algo = boostingStrategy.algo - require(algo == Regression, s"Only Regression algo supported. 
Provided algo is $algo.") - new GradientBoosting(boostingStrategy).train(input) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#train]] - */ - def train( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - train(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]] - */ - def trainClassifier( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainClassifier(input.rdd, boostingStrategy) - } - - /** - * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]] - */ - def trainRegressor( - input: JavaRDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - trainRegressor(input.rdd, boostingStrategy) - } - - /** - * Internal method for performing regression using trees as base learners. - * @param input training dataset - * @param boostingStrategy boosting parameters - * @return - */ - private def boost( - input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): WeightedEnsembleModel = { - - val timer = new TimeTracker() - timer.start("total") - timer.start("init") - - // Initialize gradient boosting parameters - val numIterations = boostingStrategy.numIterations - val baseLearners = new Array[DecisionTreeModel](numIterations) - val baseLearnerWeights = new Array[Double](numIterations) - val loss = boostingStrategy.loss - val learningRate = boostingStrategy.learningRate - val strategy = boostingStrategy.weakLearnerParams - - // Cache input - if (input.getStorageLevel == StorageLevel.NONE) { - input.persist(StorageLevel.MEMORY_AND_DISK) - } - - timer.stop("init") - - logDebug("##########") - logDebug("Building tree 0") - logDebug("##########") - var data = input - - // Initialize tree - timer.start("building tree 0") - val firstTreeModel = new DecisionTree(strategy).train(data) - baseLearners(0) 
= firstTreeModel - baseLearnerWeights(0) = 1.0 - val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression, - Sum) - logDebug("error of gbt = " + loss.computeError(startingModel, input)) - // Note: A model of type regression is used since we require raw prediction - timer.stop("building tree 0") - - // psuedo-residual for second iteration - data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), - point.features)) - - var m = 1 - while (m < numIterations) { - timer.start(s"building tree $m") - logDebug("###################################################") - logDebug("Gradient boosting tree iteration " + m) - logDebug("###################################################") - val model = new DecisionTree(strategy).train(data) - timer.stop(s"building tree $m") - // Create partial model - baseLearners(m) = model - // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. - // Technically, the weight should be optimized for the particular loss. - // However, the behavior should be reasonable, though not optimal. 
- baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1), Regression, Sum) - logDebug("error of gbt = " + loss.computeError(partialModel, input)) - // Update data with pseudo-residuals - data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), - point.features)) - m += 1 - } - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum) - - } - -} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 9683916..ca0b6ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache } +import 
org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, + TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -79,9 +79,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = { + def run(input: RDD[LabeledPoint]): RandomForestModel = { val timer = new TimeTracker() @@ -212,8 +212,7 @@ private class RandomForest ( } val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - val treeWeights = Array.fill[Double](numTrees)(1.0) - new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average) + new RandomForestModel(strategy.algo, trees) } } @@ -234,18 +233,18 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. * @param seed Random seed for bootstrapping and choosing feature subsets. 
- * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -272,7 +271,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -283,7 +282,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) @@ -302,7 +301,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainClassifier(input.rdd, numClassesForClassification, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -322,18 +321,18 @@ object RandomForest extends Serializable with Logging { * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. 
* @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], strategy: Strategy, numTrees: Int, featureSubsetStrategy: String, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -359,7 +358,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return WeightedEnsembleModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -369,7 +368,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = { + seed: Int = Utils.random.nextInt()): RandomForestModel = { val impurityType = Impurities.fromString(impurity) val strategy = new Strategy(Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo) @@ -387,7 +386,7 @@ object RandomForest extends Serializable with Logging { impurity: String, maxDepth: Int, maxBins: Int, - seed: Int): WeightedEnsembleModel = { + seed: Int): RandomForestModel = { trainRegressor(input.rdd, categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) @@ -479,5 +478,4 @@ object RandomForest extends Serializable with Logging { 3 * totalBins } } - } 
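Per the commit summary, the refactored `TreeEnsembleModel` makes the `Vote` combining strategy weight-aware, and the random forest code above now returns a `RandomForestModel` that aggregates its trees this way for classification. A Spark-free sketch of a weighted majority vote — `VoteSketch.weightedVote` is a hypothetical helper, since the real combining logic lives in the `private[tree]` ensemble model:

```scala
object VoteSketch {
  /** Weighted majority vote over per-tree class predictions: sum the
    * weights backing each distinct class and return the class whose
    * total weight is largest. */
  def weightedVote(predictions: Array[Double], weights: Array[Double]): Double = {
    require(predictions.length == weights.length, "one weight per prediction")
    predictions.zip(weights)
      .groupBy { case (cls, _) => cls }                    // bucket identical class labels
      .map { case (cls, ps) => (cls, ps.map(_._2).sum) }   // total weight per class
      .maxBy { case (_, totalWeight) => totalWeight }      // heaviest class wins
      ._1
  }
}
```

With equal weights this degenerates to the plain majority vote a random forest uses; unequal weights let a boosted ensemble reuse the same code path.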
http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index abbda04..e703adb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -25,57 +25,39 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} /** * :: Experimental :: - * Stores all the configuration options for the boosting algorithms - * @param algo Learning goal. Supported: - * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], - * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. + * + * @param treeStrategy Parameters for the tree algorithm. We support regression and binary + * classification for boosting. Impurity setting will be ignored. + * @param loss Loss function used for minimization during gradient boosting. * @param numIterations Number of iterations of boosting. In other words, the number of * weak hypotheses used in the final model. - * @param loss Loss function used for minimization during gradient boosting. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be in the interval (0, 1] - * @param numClassesForClassification Number of classes for classification. - * (Ignored for regression.) - * This setting overrides any setting in [[weakLearnerParams]]. - * Default value is 2 (binary classification). - * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are - * supported.
*/ @Experimental case class BoostingStrategy( // Required boosting parameters - @BeanProperty var algo: Algo, - @BeanProperty var numIterations: Int, + @BeanProperty var treeStrategy: Strategy, @BeanProperty var loss: Loss, // Optional boosting parameters - @BeanProperty var learningRate: Double = 0.1, - @BeanProperty var numClassesForClassification: Int = 2, - @BeanProperty var weakLearnerParams: Strategy) extends Serializable { - - // Ensure values for weak learner are the same as what is provided to the boosting algorithm. - weakLearnerParams.numClassesForClassification = numClassesForClassification - - /** - * Sets Algorithm using a String. - */ - def setAlgo(algo: String): Unit = algo match { - case "Classification" => setAlgo(Classification) - case "Regression" => setAlgo(Regression) - } + @BeanProperty var numIterations: Int = 100, + @BeanProperty var learningRate: Double = 0.1) extends Serializable { /** * Check validity of parameters. * Throws exception if invalid. */ private[tree] def assertValid(): Unit = { - algo match { + treeStrategy.algo match { case Classification => - require(numClassesForClassification == 2) + require(treeStrategy.numClassesForClassification == 2, + "Only binary classification is supported for boosting.") case Regression => // nothing case _ => throw new IllegalArgumentException( - s"BoostingStrategy given invalid algo parameter: $algo." + + s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." 
+ s" Valid settings are: Classification, Regression.") } require(learningRate > 0 && learningRate <= 1, @@ -94,14 +76,14 @@ object BoostingStrategy { * @return Configuration for boosting algorithm */ def defaultParams(algo: String): BoostingStrategy = { - val treeStrategy = Strategy.defaultStrategy("Regression") + val treeStrategy = Strategy.defaultStrategy(algo) treeStrategy.maxDepth = 3 algo match { case "Classification" => - new BoostingStrategy(Algo.withName(algo), 100, LogLoss, weakLearnerParams = treeStrategy) + treeStrategy.numClassesForClassification = 2 + new BoostingStrategy(treeStrategy, LogLoss) case "Regression" => - new BoostingStrategy(Algo.withName(algo), 100, SquaredError, - weakLearnerParams = treeStrategy) + new BoostingStrategy(treeStrategy, SquaredError) case _ => throw new IllegalArgumentException(s"$algo is not supported by the boosting.") } http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala index 82889dc..b5bf732 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -17,14 +17,10 @@ package org.apache.spark.mllib.tree.configuration -import org.apache.spark.annotation.DeveloperApi - /** - * :: Experimental :: * Enum to select ensemble combining strategy for base learners */ -@DeveloperApi -object EnsembleCombiningStrategy extends Enumeration { +private[tree] object EnsembleCombiningStrategy extends Enumeration { type EnsembleCombiningStrategy = Value - val Sum, Average = Value + val Average, Sum, Vote = Value } 
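The commit summary also notes that the caller's input strategy is no longer mutated: the boosting code works on a copy (the `Strategy.copy` shallow-copy method added in the next hunk) before forcing `algo = Regression` and `impurity = Variance` internally. The defensive-copy pattern, sketched with a hypothetical `Config` class standing in for `Strategy`:

```scala
// Hypothetical stand-in for Strategy's mutable bean-style configuration.
class Config(var algo: String, var impurity: String) {
  /** Shallow copy, analogous to the new Strategy.copy. */
  def copy: Config = new Config(algo, impurity)
}

object CopySketch {
  /** Returns the settings boosting would use, without touching `user`. */
  def internalFor(user: Config): Config = {
    val internal = user.copy         // work on a copy, never the caller's object
    internal.algo = "Regression"     // GBT always trains regression trees internally
    internal.impurity = "Variance"
    internal
  }
}
```

This is why users now put `algo` and `numClasses` in `treeStrategy`: boosting derives its internal tree strategy from a copy, so the object the user passed in comes back unchanged.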
http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index b5b1f82..d75f384 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -157,6 +157,13 @@ class Strategy ( require(maxMemoryInMB <= 10240, s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } + + /** Returns a shallow copy of this instance. */ + def copy: Strategy = { + new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + } } @Experimental http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala index d111ffe..e828866 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import 
org.apache.spark.rdd.RDD /** @@ -42,7 +42,7 @@ object AbsoluteError extends Loss { * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 } @@ -55,7 +55,7 @@ object AbsoluteError extends Loss { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { val sumOfAbsolutes = data.map { y => val err = model.predict(y.features) - y.label math.abs(err) http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala index 6f3d434..8b8adb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -42,7 +42,7 @@ object LogLoss extends Loss { * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { val prediction = model.predict(point.features) 1.0 / (1.0 + math.exp(-prediction)) - point.label @@ -56,7 +56,7 @@ object LogLoss extends Loss { * @param data Training dataset: RDD of 
[[org.apache.spark.mllib.regression.LabeledPoint]]. * @return */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count() wrongPredictions / data.count } http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 5580866..4bca903 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -36,7 +36,7 @@ trait Loss extends Serializable { * @return Loss gradient. */ def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double /** @@ -47,6 +47,6 @@ trait Loss extends Serializable { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
* @return */ - def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double } http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala index 4349fef..cfe395b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss import org.apache.spark.SparkContext._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.WeightedEnsembleModel +import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.rdd.RDD /** @@ -43,7 +43,7 @@ object SquaredError extends Loss { * @return Loss gradient */ override def gradient( - model: WeightedEnsembleModel, + model: TreeEnsembleModel, point: LabeledPoint): Double = { model.predict(point.features) - point.label } @@ -56,7 +56,7 @@ object SquaredError extends Loss { * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. 
* @return */ - override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = { + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { data.map { y => val err = model.predict(y.features) - y.label err * err http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ac4d02e..a576096 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,11 +17,11 @@ package org.apache.spark.mllib.tree.model -import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala deleted file mode 100644 index 7b052d9..0000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ -import org.apache.spark.rdd.RDD - -import scala.collection.mutable - -@Experimental -class WeightedEnsembleModel( - val weakHypotheses: Array[DecisionTreeModel], - val weakHypothesisWeights: Array[Double], - val algo: Algo, - val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { - - require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" + - s". Number of weakHypotheses = $weakHypotheses") - - /** - * Predict values for a single data point using the model trained. 
- * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictRaw(features: Vector): Double = { - val treePredictions = weakHypotheses.map(learner => learner.predict(features)) - if (numWeakHypotheses == 1){ - treePredictions(0) - } else { - var prediction = treePredictions(0) - var index = 1 - while (index < numWeakHypotheses) { - prediction += weakHypothesisWeights(index) * treePredictions(index) - index += 1 - } - prediction - } - } - - /** - * Predict values for a single data point using the model trained. - * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - private def predictBySumming(features: Vector): Double = { - algo match { - case Regression => predictRaw(features) - case Classification => { - // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. - if (predictRaw(features) > 0 ) 1.0 else 0.0 - } - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - private def predictByAveraging(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - weakHypotheses.foreach { learner => - val prediction = learner.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size - } - } - - - /** - * Predict values for a single data point using the model trained. 
- * - * @param features array representing a single data point - * @return predicted category from the trained model - */ - def predict(features: Vector): Double = { - combiningStrategy match { - case Sum => predictBySumming(features) - case Average => predictByAveraging(features) - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.") - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) - - /** - * Print a summary of the model. - */ - override def toString: String = { - algo match { - case Classification => - s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n" - case Regression => - s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n" - case _ => throw new IllegalArgumentException( - s"WeightedEnsembleModel given unknown algo parameter: $algo.") - } - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - - /** - * Get number of trees in forest. - */ - def numWeakHypotheses: Int = weakHypotheses.size - - // TODO: Remove these helpers methods once class is generalized to support any base learning - // algorithms. - - /** - * Get total number of nodes, summed over all trees in the forest. 
- */ - def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum - -} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala new file mode 100644 index 0000000..2299711 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Represents a random forest model. 
+ * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + */ +@Experimental +class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) + extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) { + + require(trees.forall(_.algo == algo)) +} + +/** + * :: Experimental :: + * Represents a gradient boosted trees model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + */ +@Experimental +class GradientBoostedTreesModel( + override val algo: Algo, + override val trees: Array[DecisionTreeModel], + override val treeWeights: Array[Double]) + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + + require(trees.size == treeWeights.size) +} + +/** + * Represents a tree ensemble model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + * @param combiningStrategy strategy for combining the predictions, not used for regression. + */ +private[tree] sealed class TreeEnsembleModel( + protected val algo: Algo, + protected val trees: Array[DecisionTreeModel], + protected val treeWeights: Array[Double], + protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.") + + private val sumWeights = math.max(treeWeights.sum, 1e-15) + + /** + * Predicts for a single data point using the weighted sum of ensemble predictions. 
+ * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + val treePredictions = trees.map(_.predict(features)) + blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + } + + /** + * Classifies a single data point based on (weighted) majority votes. + */ + private def predictByVoting(features: Vector): Double = { + val votes = mutable.Map.empty[Int, Double] + trees.view.zip(treeWeights).foreach { case (tree, weight) => + val prediction = tree.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + weight + } + votes.maxBy(_._2)._1 + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + (algo, combiningStrategy) match { + case (Regression, Sum) => + predictBySumming(features) + case (Regression, Average) => + predictBySumming(features) / sumWeights + case (Classification, Sum) => // binary classification + val prediction = predictBySumming(features) + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (prediction > 0.0) 1.0 else 0.0 + case (Classification, Vote) => + predictByVoting(features) + case _ => + throw new IllegalArgumentException( + "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + + s"($algo, $combiningStrategy).") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. 
+ */ + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + } + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"TreeEnsembleModel classifier with $numTrees trees\n" + case Regression => + s"TreeEnsembleModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"TreeEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Get total number of nodes, summed over all trees in the forest. + */ + def totalNumNodes: Int = trees.map(_.numNodes).sum +} http://git-wip-us.apache.org/repos/asf/spark/blob/e958132a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 2c281a1..9925aae 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -74,7 +74,7 @@ public class JavaDecisionTreeSuite implements Serializable { maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); - DecisionTreeModel model = learner.train(rdd.rdd()); + DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); 
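The new `TreeEnsembleModel` above combines tree predictions two ways: a weighted sum (via `blas.ddot`) for boosting, and a weighted majority vote for random-forest classification. A plain-Scala sketch of both combiners, operating on arrays that stand in for `trees.map(_.predict(features))` and `treeWeights` (this is an illustration of the logic, not Spark code):

```scala
// Weighted sum of per-tree predictions; equivalent to the blas.ddot call
// in predictBySumming, written as an explicit map/sum for clarity.
def predictBySumming(treePredictions: Array[Double], treeWeights: Array[Double]): Double =
  treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum

// Weighted majority vote, as in predictByVoting: each tree's class vote
// is weighted, and the class with the largest total weight wins.
def predictByVoting(treePredictions: Array[Int], treeWeights: Array[Double]): Int = {
  val votes = scala.collection.mutable.Map.empty[Int, Double]
  treePredictions.zip(treeWeights).foreach { case (p, w) =>
    votes(p) = votes.getOrElse(p, 0.0) + w
  }
  votes.maxBy(_._2)._1
}

// Two trees vote 1, one votes 0, but the 0-vote carries more weight, so 0 wins.
assert(predictByVoting(Array(1, 1, 0), Array(0.3, 0.3, 1.0)) == 0)
```

Note that unlike the old `WeightedEnsembleModel.predictByAveraging`, the vote here takes the tree weights into account; for a `RandomForestModel` all weights are 1.0, so this reduces to a plain majority vote.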
