[SPARK-5995] [ML] Make Prediction dev API public

Changes:
* Update protected prediction methods, following design doc. **<--most interesting change**
* Changed abstract classes for Estimator and Model to be public. Added DeveloperApi tag. (I kept the traits for Estimator/Model Params private.)
* Changed ProbabilisticClassificationModel method names to use probability instead of probabilities.
CC: mengxr shivaram etrain Author: Joseph K. Bradley <[email protected]> Closes #5913 from jkbradley/public-dev-api and squashes the following commits: e9aa0ea [Joseph K. Bradley] moved findMax to DenseVector and renamed to argmax. fixed bug for vector of length 0 15b9957 [Joseph K. Bradley] renamed probabilities to probability in method names 5cda84d [Joseph K. Bradley] regenerated sharedParams 7d1877a [Joseph K. Bradley] Made spark.ml prediction abstractions public. Organized their prediction methods for efficient computation of multiple output columns. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ad04dae Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ad04dae Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ad04dae Branch: refs/heads/master Commit: 1ad04dae038673a448f529c39b17817b78d6acd0 Parents: 7740996 Author: Joseph K. Bradley <[email protected]> Authored: Wed May 6 16:15:51 2015 -0700 Committer: Xiangrui Meng <[email protected]> Committed: Wed May 6 16:15:51 2015 -0700 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/Predictor.scala | 191 ++++++++ .../spark/ml/classification/Classifier.scala | 110 ++--- .../classification/DecisionTreeClassifier.scala | 5 +- .../spark/ml/classification/GBTClassifier.scala | 5 +- .../ml/classification/LogisticRegression.scala | 100 ++--- .../ProbabilisticClassifier.scala | 100 +++-- .../classification/RandomForestClassifier.scala | 5 +- .../spark/ml/impl/estimator/Predictor.scala | 217 ---------- .../apache/spark/ml/impl/tree/treeParams.scala | 431 ------------------- .../ml/param/shared/SharedParamsCodeGen.scala | 6 +- .../spark/ml/param/shared/sharedParams.scala | 4 +- .../ml/regression/DecisionTreeRegressor.scala | 5 +- .../spark/ml/regression/GBTRegressor.scala | 5 +- .../spark/ml/regression/LinearRegression.scala | 5 +- .../ml/regression/RandomForestRegressor.scala | 5 +- 
.../apache/spark/ml/regression/Regressor.scala | 42 +- .../org/apache/spark/ml/tree/treeParams.scala | 431 +++++++++++++++++++ .../org/apache/spark/mllib/linalg/Vectors.scala | 22 + 18 files changed, 814 insertions(+), 875 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala new file mode 100644 index 0000000..0e53877 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +/** + * (private[ml]) Trait for parameters for prediction (regression and classification). + */ +private[ml] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) + if (fitting) { + // TODO: Allow other numeric types + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + } + SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) + } +} + +/** + * :: DeveloperApi :: + * + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. 
If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + */ +@DeveloperApi +abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + /** @group setParam */ + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). + transformSchema(dataset.schema, logging = true) + copyValues(train(dataset)) + } + + override def copy(extra: ParamMap): Learner = { + super.copy(extra).asInstanceOf[Learner] + } + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + protected def train(dataset: DataFrame): M + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. 
+ */ + protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + dataset.select($(labelCol), $(featuresCol)) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + } +} + +/** + * :: DeveloperApi :: + * + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + */ +@DeveloperApi +abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + /** @group setParam */ + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + protected def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. 
+ * + * @param dataset input dataset + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + if ($(predictionCol).nonEmpty) { + dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + protected def predict(features: FeaturesType): Double +} http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d3361e2..263d580 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -26,15 +26,12 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + /** - * :: DeveloperApi :: - * Params for classification. 
- * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + * (private[spark]) Params for classification. */ -@DeveloperApi -private[spark] trait ClassifierParams extends PredictorParams - with HasRawPredictionCol { +private[spark] trait ClassifierParams + extends PredictorParams with HasRawPredictionCol { override protected def validateAndTransformSchema( schema: StructType, @@ -46,23 +43,21 @@ private[spark] trait ClassifierParams extends PredictorParams } /** - * :: AlphaComponent :: + * :: DeveloperApi :: + * * Single-label binary or multiclass classification. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam E Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class Classifier[ +@DeveloperApi +abstract class Classifier[ FeaturesType, E <: Classifier[FeaturesType, E, M], M <: ClassificationModel[FeaturesType, M]] - extends Predictor[FeaturesType, E, M] - with ClassifierParams { + extends Predictor[FeaturesType, E, M] with ClassifierParams { /** @group setParam */ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] @@ -71,17 +66,15 @@ private[spark] abstract class Classifier[ } /** - * :: AlphaComponent :: + * :: DeveloperApi :: + * * Model produced by a [[Classifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
*/ -@AlphaComponent -private[spark] +@DeveloperApi abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] extends PredictionModel[FeaturesType, M] with ClassifierParams { @@ -101,13 +94,27 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @return transformed dataset */ override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. - - // Check schema transformSchema(dataset.schema, logging = true) - val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. + var outputData = dataset + var numColsOutput = 0 + if (getRawPredictionCol != "") { + outputData = outputData.withColumn(getRawPredictionCol, + callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + numColsOutput += 1 + } + if (getPredictionCol != "") { + val predUDF = if (getRawPredictionCol != "") { + callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) + } else { + callUDF(predict _, DoubleType, col(getFeaturesCol)) + } + outputData = outputData.withColumn(getPredictionCol, predUDF) + numColsOutput += 1 + } + if (numColsOutput == 0) { logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") @@ -116,22 +123,17 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur } /** - * :: DeveloperApi :: - * * Predict label for the given features. * This internal method is used to implement [[transform()]] and output [[predictionCol]]. * * This default implementation for classification predicts the index of the maximum value * from [[predictRaw()]]. 
*/ - @DeveloperApi override protected def predict(features: FeaturesType): Double = { - predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + raw2prediction(predictRaw(features)) } /** - * :: DeveloperApi :: - * * Raw prediction for each possible label. * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives * a measure of confidence in each possible label (where larger = more confident). @@ -141,48 +143,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This raw prediction may be any real number, where a larger value indicates greater * confidence for that label. */ - @DeveloperApi protected def predictRaw(features: FeaturesType): Vector -} - -private[ml] object ClassificationModel { /** - * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] - * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. - * @param dataset Input dataset - * @return (number of columns added, transformed dataset) + * Given a vector of raw predictions, select the predicted label. + * This may be overridden to support thresholds which favor particular labels. + * @return predicted label */ - def transformColumnsImpl[FeaturesType]( - dataset: DataFrame, - model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = { - - // Output selected columns only. - // This is a bit complicated since it tries to avoid repeated computation. 
- var tmpData = dataset - var numColsOutput = 0 - if (model.getRawPredictionCol != "") { - // output raw prediction - val features2raw: FeaturesType => Vector = model.predictRaw - tmpData = tmpData.withColumn(model.getRawPredictionCol, - callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol))) - numColsOutput += 1 - if (model.getPredictionCol != "") { - val raw2pred: Vector => Double = (rawPred) => { - rawPred.toArray.zipWithIndex.maxBy(_._1)._2 - } - tmpData = tmpData.withColumn(model.getPredictionCol, - callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol))) - numColsOutput += 1 - } - } else if (model.getPredictionCol != "") { - // output prediction - val features2pred: FeaturesType => Double = model.predict - tmpData = tmpData.withColumn(model.getPredictionCol, - callUDF(features2pred, DoubleType, col(model.getFeaturesCol))) - numColsOutput += 1 - } - (numColsOutput, tmpData) - } - + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax } http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 419e5ba..dcebea1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, 
Node} +import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 534ea95..ae51b05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -21,11 +21,10 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index b73be03..550369d 
100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -21,9 +21,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions._ import org.apache.spark.storage.StorageLevel /** @@ -99,76 +98,17 @@ class LogisticRegressionModel private[ml] ( /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) + /** Margin (rawPrediction) for class label 1. For binary classification only. */ private val margin: Vector => Double = (features) => { BLAS.dot(features, weights) + intercept } + /** Score (probability) for class label 1. For binary classification only. */ private val score: Vector => Double = (features) => { val m = margin(features) 1.0 / (1.0 + math.exp(-m)) } - override def transform(dataset: DataFrame): DataFrame = { - // This is overridden (a) to be more efficient (avoiding re-computing values when creating - // multiple output columns) and (b) to handle threshold, which the abstractions do not use. - // TODO: We should abstract away the steps defined by UDFs below so that the abstractions - // can call whichever UDFs are needed to create the output columns. - - // Check schema - transformSchema(dataset.schema, logging = true) - - // Output selected columns only. - // This is a bit complicated since it tries to avoid repeated computation. 
- // rawPrediction (-margin, margin) - // probability (1.0-score, score) - // prediction (max margin) - var tmpData = dataset - var numColsOutput = 0 - if ($(rawPredictionCol) != "") { - val features2raw: Vector => Vector = (features) => predictRaw(features) - tmpData = tmpData.withColumn($(rawPredictionCol), - callUDF(features2raw, new VectorUDT, col($(featuresCol)))) - numColsOutput += 1 - } - if ($(probabilityCol) != "") { - if ($(rawPredictionCol) != "") { - val raw2prob = udf { (rawPreds: Vector) => - val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) - Vectors.dense(1.0 - prob1, prob1): Vector - } - tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol)))) - } else { - val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector } - tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol)))) - } - numColsOutput += 1 - } - if ($(predictionCol) != "") { - val t = $(threshold) - if ($(probabilityCol) != "") { - val predict = udf { probs: Vector => - if (probs(1) > t) 1.0 else 0.0 - } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol)))) - } else if ($(rawPredictionCol) != "") { - val predict = udf { rawPreds: Vector => - val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) - if (prob1 > t) 1.0 else 0.0 - } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol)))) - } else { - val predict = udf { features: Vector => this.predict(features) } - tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol)))) - } - numColsOutput += 1 - } - if (numColsOutput == 0) { - this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + - " since no output columns were set.") - } - tmpData - } - override val numClasses: Int = 2 /** @@ -179,17 +119,43 @@ class LogisticRegressionModel private[ml] ( if (score(features) > getThreshold) 1 else 0 } - override protected def predictProbabilities(features: Vector): Vector = { - 
val s = score(features) - Vectors.dense(1.0 - s, s) + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + while (i < dv.size) { + dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i))) + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in LogisticRegressionModel:" + + " raw2probabilitiesInPlace encountered SparseVector") + } } override protected def predictRaw(features: Vector): Vector = { val m = margin(features) - Vectors.dense(0.0, m) + Vectors.dense(-m, m) } override def copy(extra: ParamMap): LogisticRegressionModel = { copyValues(new LogisticRegressionModel(parent, weights, intercept), extra) } + + override protected def raw2prediction(rawPrediction: Vector): Double = { + val t = getThreshold + val rawThreshold = if (t == 0.0) { + Double.NegativeInfinity + } else if (t == 1.0) { + Double.PositiveInfinity + } else { + Math.log(t / (1.0 - t)) + } + if (rawPrediction(1) > rawThreshold) 1 else 0 + } + + override protected def probability2prediction(probability: Vector): Double = { + if (probability(1) > getThreshold) 1 else 0 + } } http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 8519841..330ae29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.annotation.DeveloperApi import 
org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** - * Params for probabilistic classification. + * (private[classification]) Params for probabilistic classification. */ private[classification] trait ProbabilisticClassifierParams extends ClassifierParams with HasProbabilityCol { @@ -42,17 +42,15 @@ private[classification] trait ProbabilisticClassifierParams /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Single-label binary or multiclass classifier which can output class conditional probabilities. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam E Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent +@DeveloperApi private[spark] abstract class ProbabilisticClassifier[ FeaturesType, E <: ProbabilisticClassifier[FeaturesType, E, M], @@ -65,17 +63,15 @@ private[spark] abstract class ProbabilisticClassifier[ /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Model produced by a [[ProbabilisticClassifier]]. * Classes are indexed {0, 1, ..., numClasses - 1}. * * @tparam FeaturesType Type of input features. E.g., [[Vector]] * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
*/ -@AlphaComponent +@DeveloperApi private[spark] abstract class ProbabilisticClassificationModel[ FeaturesType, M <: ProbabilisticClassificationModel[FeaturesType, M]] @@ -95,39 +91,79 @@ private[spark] abstract class ProbabilisticClassificationModel[ * @return transformed dataset */ override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. - - // Check schema transformSchema(dataset.schema, logging = true) - val (numColsOutput, outputData) = - ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this) - // Output selected columns only. - if ($(probabilityCol) != "") { - // output probabilities - outputData.withColumn($(probabilityCol), - callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol)))) - } else { - if (numColsOutput == 0) { - this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + - " since no output columns were set.") + // This is a bit complicated since it tries to avoid repeated computation. 
+ var outputData = dataset + var numColsOutput = 0 + if ($(rawPredictionCol).nonEmpty) { + outputData = outputData.withColumn(getRawPredictionCol, + callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + numColsOutput += 1 + } + if ($(probabilityCol).nonEmpty) { + val probUDF = if ($(rawPredictionCol).nonEmpty) { + callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol))) + } else { + callUDF(predictProbability _, new VectorUDT, col($(featuresCol))) + } + outputData = outputData.withColumn($(probabilityCol), probUDF) + numColsOutput += 1 + } + if ($(predictionCol).nonEmpty) { + val predUDF = if ($(rawPredictionCol).nonEmpty) { + callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + } else if ($(probabilityCol).nonEmpty) { + callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + } else { + callUDF(predict _, DoubleType, col($(featuresCol))) } - outputData + outputData = outputData.withColumn($(predictionCol), predUDF) + numColsOutput += 1 + } + + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") } + outputData } /** - * :: DeveloperApi :: + * Estimate the probability of each class given the raw prediction, + * doing the computation in-place. + * These predictions are also called class conditional probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. * + * @return Estimated class conditional probabilities (modified input vector) + */ + protected def raw2probabilityInPlace(rawPrediction: Vector): Vector + + /** Non-in-place version of [[raw2probabilityInPlace()]] */ + protected def raw2probability(rawPrediction: Vector): Vector = { + val probs = rawPrediction.copy + raw2probabilityInPlace(probs) + } + + /** * Predict the probability of each class given the features. * These predictions are also called class conditional probabilities. 
* - * WARNING: Not all models output well-calibrated probability estimates! These probabilities - * should be treated as confidences, not precise probabilities. - * * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. + * + * @return Estimated class conditional probabilities + */ + protected def predictProbability(features: FeaturesType): Vector = { + val rawPreds = predictRaw(features) + raw2probabilityInPlace(rawPreds) + } + + /** + * Given a vector of class conditional probabilities, select the predicted label. + * This may be overridden to support thresholds which favor particular labels. + * @return predicted label */ - @DeveloperApi - protected def predictProbabilities(features: FeaturesType): Vector + protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax } http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 17f59bb..9954893 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -20,10 +20,9 @@ package org.apache.spark.ml.classification import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, 
TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala deleted file mode 100644 index e8b3628..0000000 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.impl.estimator - -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, DoubleType, StructType} - -/** - * :: DeveloperApi :: - * - * Trait for parameters for prediction (regression and classification). - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@DeveloperApi -private[spark] trait PredictorParams extends Params - with HasLabelCol with HasFeaturesCol with HasPredictionCol { - - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param fitting whether this is in fitting - * @param featuresDataType SQL DataType for FeaturesType. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - fitting: Boolean, - featuresDataType: DataType): StructType = { - // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) - if (fitting) { - // TODO: Allow other numeric types - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) - } - SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) - } -} - -/** - * :: AlphaComponent :: - * - * Abstraction for prediction problems (regression and classification). - * - * @tparam FeaturesType Type of features. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. 
- * @tparam Learner Specialization of this class. If you subclass this type, use this type - * parameter to specify the concrete type. - * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type - * parameter to specify the concrete type for the corresponding model. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@AlphaComponent -private[spark] abstract class Predictor[ - FeaturesType, - Learner <: Predictor[FeaturesType, Learner, M], - M <: PredictionModel[FeaturesType, M]] - extends Estimator[M] with PredictorParams { - - /** @group setParam */ - def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] - - /** @group setParam */ - def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] - - /** @group setParam */ - def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - - override def fit(dataset: DataFrame): M = { - // This handles a few items such as schema validation. - // Developers only need to implement train(). - transformSchema(dataset.schema, logging = true) - copyValues(train(dataset)) - } - - override def copy(extra: ParamMap): Learner = { - super.copy(extra).asInstanceOf[Learner] - } - - /** - * :: DeveloperApi :: - * - * Train a model using the given dataset and parameters. - * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation - * and copying parameters into the model. - * - * @param dataset Training dataset - * @return Fitted model - */ - @DeveloperApi - protected def train(dataset: DataFrame): M - - /** - * :: DeveloperApi :: - * - * Returns the SQL DataType corresponding to the FeaturesType type parameter. - * - * This is used by [[validateAndTransformSchema()]]. - * This workaround is needed since SQL has different APIs for Scala and Java. 
- * - * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. - */ - @DeveloperApi - protected def featuresDataType: DataType = new VectorUDT - - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, fitting = true, featuresDataType) - } - - /** - * Extract [[labelCol]] and [[featuresCol]] from the given dataset, - * and put it in an RDD with strong types. - */ - protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } - } -} - -/** - * :: AlphaComponent :: - * - * Abstraction for a model for prediction tasks (regression and classification). - * - * @tparam FeaturesType Type of features. - * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. - * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type - * parameter to specify the concrete type for the corresponding model. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@AlphaComponent -private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] - extends Model[M] with PredictorParams { - - /** @group setParam */ - def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] - - /** @group setParam */ - def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] - - /** - * :: DeveloperApi :: - * - * Returns the SQL DataType corresponding to the FeaturesType type parameter. - * - * This is used by [[validateAndTransformSchema()]]. - * This workaround is needed since SQL has different APIs for Scala and Java. - * - * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. 
- */ - @DeveloperApi - protected def featuresDataType: DataType = new VectorUDT - - override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema, fitting = false, featuresDataType) - } - - /** - * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing - * the predictions as a new column [[predictionCol]]. - * - * @param dataset input dataset - * @return transformed dataset with [[predictionCol]] of type [[Double]] - */ - override def transform(dataset: DataFrame): DataFrame = { - // This default implementation should be overridden as needed. - - // Check schema - transformSchema(dataset.schema, logging = true) - - if ($(predictionCol) != "") { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) - } else { - this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + - " since no output columns were set.") - dataset - } - } - - /** - * :: DeveloperApi :: - * - * Predict label for the given features. - * This internal method is used to implement [[transform()]] and output [[predictionCol]]. - */ - @DeveloperApi - protected def predict(features: FeaturesType): Double -} http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala deleted file mode 100644 index 0e22562..0000000 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.impl.tree - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.impl.estimator.PredictorParams -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} -import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} - -/** - * :: DeveloperApi :: - * Parameters for Decision Tree-based algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -@DeveloperApi -private[ml] trait DecisionTreeParams extends PredictorParams { - - /** - * Maximum depth of the tree (>= 0). - * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. - * (default = 5) - * @group param - */ - final val maxDepth: IntParam = - new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" + - " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", - ParamValidators.gtEq(0)) - - /** - * Maximum number of bins used for discretizing continuous features and for choosing how to split - * on features at each node. More bins give higher granularity. 
- * Must be >= 2 and >= number of categories in any categorical feature. - * (default = 32) - * @group param - */ - final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + - " discretizing continuous features. Must be >=2 and >= number of categories for any" + - " categorical feature.", ParamValidators.gtEq(2)) - - /** - * Minimum number of instances each child must have after split. - * If a split causes the left or right child to have fewer than minInstancesPerNode, - * the split will be discarded as invalid. - * Should be >= 1. - * (default = 1) - * @group param - */ - final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + - " number of instances each child must have after split. If a split causes the left or right" + - " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + - " Should be >= 1.", ParamValidators.gtEq(1)) - - /** - * Minimum information gain for a split to be considered at a tree node. - * (default = 0.0) - * @group param - */ - final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", - "Minimum information gain for a split to be considered at a tree node.") - - /** - * Maximum memory in MB allocated to histogram aggregation. - * (default = 256 MB) - * @group expertParam - */ - final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", - "Maximum memory in MB allocated to histogram aggregation.", - ParamValidators.gtEq(0)) - - /** - * If false, the algorithm will pass trees to executors to match instances with nodes. - * If true, the algorithm will cache node IDs for each instance. - * Caching can speed up training of deeper trees. - * (default = false) - * @group expertParam - */ - final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + - " algorithm will pass trees to executors to match instances with nodes. 
If true, the" + - " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" + - " trees.") - - /** - * Specifies how often to checkpoint the cached node IDs. - * E.g. 10 means that the cache will get checkpointed every 10 iterations. - * This is only used if cacheNodeIds is true and if the checkpoint directory is set in - * [[org.apache.spark.SparkContext]]. - * Must be >= 1. - * (default = 10) - * @group expertParam - */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + - " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + - " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + - " checkpoint directory is set in the SparkContext. Must be >= 1.", - ParamValidators.gtEq(1)) - - setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, - maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) - - /** @group setParam */ - def setMaxDepth(value: Int): this.type = set(maxDepth, value) - - /** @group getParam */ - final def getMaxDepth: Int = $(maxDepth) - - /** @group setParam */ - def setMaxBins(value: Int): this.type = set(maxBins, value) - - /** @group getParam */ - final def getMaxBins: Int = $(maxBins) - - /** @group setParam */ - def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) - - /** @group getParam */ - final def getMinInstancesPerNode: Int = $(minInstancesPerNode) - - /** @group setParam */ - def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) - - /** @group getParam */ - final def getMinInfoGain: Double = $(minInfoGain) - - /** @group expertSetParam */ - def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) - - /** @group expertGetParam */ - final def getMaxMemoryInMB: Int = $(maxMemoryInMB) - - /** @group expertSetParam */ - def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, 
value) - - /** @group expertGetParam */ - final def getCacheNodeIds: Boolean = $(cacheNodeIds) - - /** @group expertSetParam */ - def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) - - /** @group expertGetParam */ - final def getCheckpointInterval: Int = $(checkpointInterval) - - /** (private[ml]) Create a Strategy instance to use with the old API. */ - private[ml] def getOldStrategy( - categoricalFeatures: Map[Int, Int], - numClasses: Int, - oldAlgo: OldAlgo.Algo, - oldImpurity: OldImpurity, - subsamplingRate: Double): OldStrategy = { - val strategy = OldStrategy.defaultStategy(oldAlgo) - strategy.impurity = oldImpurity - strategy.checkpointInterval = getCheckpointInterval - strategy.maxBins = getMaxBins - strategy.maxDepth = getMaxDepth - strategy.maxMemoryInMB = getMaxMemoryInMB - strategy.minInfoGain = getMinInfoGain - strategy.minInstancesPerNode = getMinInstancesPerNode - strategy.useNodeIdCache = getCacheNodeIds - strategy.numClasses = numClasses - strategy.categoricalFeaturesInfo = categoricalFeatures - strategy.subsamplingRate = subsamplingRate - strategy - } -} - -/** - * Parameters for Decision Tree-based classification algorithms. - */ -private[ml] trait TreeClassifierParams extends Params { - - /** - * Criterion used for information gain calculation (case-insensitive). - * Supported: "entropy" and "gini". - * (default = gini) - * @group param - */ - final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + - " information gain calculation (case-insensitive). 
Supported options:" + - s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) - - setDefault(impurity -> "gini") - - /** @group setParam */ - def setImpurity(value: String): this.type = set(impurity, value) - - /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase - - /** Convert new impurity to old impurity. */ - private[ml] def getOldImpurity: OldImpurity = { - getImpurity match { - case "entropy" => OldEntropy - case "gini" => OldGini - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException( - s"TreeClassifierParams was given unrecognized impurity: $impurity.") - } - } -} - -private[ml] object TreeClassifierParams { - // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) -} - -/** - * Parameters for Decision Tree-based regression algorithms. - */ -private[ml] trait TreeRegressorParams extends Params { - - /** - * Criterion used for information gain calculation (case-insensitive). - * Supported: "variance". - * (default = variance) - * @group param - */ - final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + - " information gain calculation (case-insensitive). Supported options:" + - s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", - (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) - - setDefault(impurity -> "variance") - - /** @group setParam */ - def setImpurity(value: String): this.type = set(impurity, value) - - /** @group getParam */ - final def getImpurity: String = $(impurity).toLowerCase - - /** Convert new impurity to old impurity. */ - private[ml] def getOldImpurity: OldImpurity = { - getImpurity match { - case "variance" => OldVariance - case _ => - // Should never happen because of check in setter method. 
- throw new RuntimeException( - s"TreeRegressorParams was given unrecognized impurity: $impurity") - } - } -} - -private[ml] object TreeRegressorParams { - // These options should be lowercase. - final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) -} - -/** - * :: DeveloperApi :: - * Parameters for Decision Tree-based ensemble algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -@DeveloperApi -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { - - /** - * Fraction of the training data used for learning each decision tree, in range (0, 1]. - * (default = 1.0) - * @group param - */ - final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", - "Fraction of the training data used for learning each decision tree, in range (0, 1].", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) - - setDefault(subsamplingRate -> 1.0) - - /** @group setParam */ - def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) - - /** @group getParam */ - final def getSubsamplingRate: Double = $(subsamplingRate) - - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - - /** - * Create a Strategy instance to use with the old API. - * NOTE: The caller should set impurity and seed. - */ - private[ml] def getOldStrategy( - categoricalFeatures: Map[Int, Int], - numClasses: Int, - oldAlgo: OldAlgo.Algo, - oldImpurity: OldImpurity): OldStrategy = { - super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) - } -} - -/** - * :: DeveloperApi :: - * Parameters for Random Forest algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -@DeveloperApi -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). 
- * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - /** - * The number of features to consider for splits at each tree node. - * Supported options: - * - "auto": Choose automatically for task: - * If numTrees == 1, set to "all." - * If numTrees > 1 (forest), set to "sqrt" for classification and - * to "onethird" for regression. - * - "all": use all features - * - "onethird": use 1/3 of the features - * - "sqrt": use sqrt(number of features) - * - "log2": use log2(number of features) - * (default = "auto") - * - * These various settings are based on the following references: - * - log2: tested in Breiman (2001) - * - sqrt: recommended by Breiman manual for random forests - * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest - * package. - * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] - * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for - * random forests]] - * - * @group param - */ - final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", - "The number of features to consider for splits at each tree node." 
+ - s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", - (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) - - setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") - - /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) - - /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) - - /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase -} - -private[ml] object RandomForestParams { - // These options should be lowercase. - final val supportedFeatureSubsetStrategies: Array[String] = - Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) -} - -/** - * :: DeveloperApi :: - * Parameters for Gradient-Boosted Tree algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -@DeveloperApi -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { - - /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. - * (default = 0.1) - * @group param - */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + - " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", - ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) - - /* TODO: Add this doc when we add this param. SPARK-7132 - * Threshold for stopping early when runWithValidation is used. - * If the error rate on the validation input changes by less than the validationTol, - * then learning will stop early (before [[numIterations]]). - * This parameter is ignored when run is used. 
- * (default = 1e-5) - * @group param - */ - // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") - // validationTol -> 1e-5 - - setDefault(maxIter -> 20, stepSize -> 0.1) - - /** @group setParam */ - def setMaxIter(value: Int): this.type = set(maxIter, value) - - /** @group setParam */ - def setStepSize(value: Double): this.type = set(stepSize, value) - - /** @group getParam */ - final def getStepSize: Double = $(stepSize) - - /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ - private[ml] def getOldBoostingStrategy( - categoricalFeatures: Map[Int, Int], - oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { - val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) - // NOTE: The old API does not support "seed" so we ignore it. - new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) - } - - /** Get old Gradient Boosting Loss type */ - private[ml] def getOldLossType: OldLoss -} http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index d379172..0e1ff97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -40,8 +40,10 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. 
confidence) column name", Some("\"rawPrediction\"")), - ParamDesc[String]("probabilityCol", - "column name for predicted class conditional probabilities", Some("\"probability\"")), + ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" + + " probabilities. Note: Not all models output well-calibrated probability estimates!" + + " These probabilities should be treated as confidences, not precise probabilities.", + Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", isValid = "ParamValidators.inRange(0, 1)"), http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index fb1874c..87f8680 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -128,10 +128,10 @@ private[ml] trait HasRawPredictionCol extends Params { private[ml] trait HasProbabilityCol extends Params { /** - * Param for column name for predicted class conditional probabilities. + * Param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.. * @group param */ - final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! 
These probabilities should be treated as confidences, not precise probabilities.") setDefault(probabilityCol, "probability") http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index b07c26f..f8f0b16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, Node} +import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index bc79695..461905c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -21,10 +21,9 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas} import org.apache.spark.Logging import 
org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree._ +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 66c475f..e63c9a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -39,7 +40,7 @@ import org.apache.spark.util.StatCounter /** * Params for linear regression. 
*/ -private[regression] trait LinearRegressionParams extends RegressorParams +private[regression] trait LinearRegressionParams extends PredictorParams with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol /** @@ -240,7 +241,7 @@ class LinearRegressionModel private[ml] ( * + \bar{y} / \hat{y}||^2 * = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2 * }}} - * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is + * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is * {{{ * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}. * }}}, and diff is http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 0468a1b..dbc6289 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor} -import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams} +import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel} +import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint 
http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c6b3327..c72ef29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -17,62 +17,40 @@ package org.apache.spark.ml.regression -import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} -/** - * :: DeveloperApi :: - * Params for regression. - * Currently empty, but may add functionality later. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. - */ -@DeveloperApi -private[spark] trait RegressorParams extends PredictorParams /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Single-label regression * * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] * @tparam Learner Concrete Estimator type * @tparam M Concrete Model type - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent +@DeveloperApi private[spark] abstract class Regressor[ FeaturesType, Learner <: Regressor[FeaturesType, Learner, M], M <: RegressionModel[FeaturesType, M]] - extends Predictor[FeaturesType, Learner, M] - with RegressorParams { + extends Predictor[FeaturesType, Learner, M] with PredictorParams { // TODO: defaultEvaluator (follow-up PR) } /** - * :: AlphaComponent :: + * :: DeveloperApi :: * * Model produced by a [[Regressor]]. 
* * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]] * @tparam M Concrete Model type. - * - * NOTE: This is currently private[spark] but will be made public later once it is stabilized. */ -@AlphaComponent -private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] - extends PredictionModel[FeaturesType, M] with RegressorParams { - - /** - * :: DeveloperApi :: - * - * Predict real-valued label for the given features. - * This internal method is used to implement [[transform()]] and output [[predictionCol]]. - */ - @DeveloperApi - protected def predict(features: FeaturesType): Double +@DeveloperApi +abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with PredictorParams { + // TODO: defaultEvaluator (follow-up PR) } http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala new file mode 100644 index 0000000..816fced --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} +import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait DecisionTreeParams extends PredictorParams { + + /** + * Maximum depth of the tree (>= 0). + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (default = 5) + * @group param + */ + final val maxDepth: IntParam = + new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" + + " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.", + ParamValidators.gtEq(0)) + + /** + * Maximum number of bins used for discretizing continuous features and for choosing how to split + * on features at each node. More bins give higher granularity. + * Must be >= 2 and >= number of categories in any categorical feature. 
+ * (default = 32) + * @group param + */ + final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" + + " discretizing continuous features. Must be >=2 and >= number of categories for any" + + " categorical feature.", ParamValidators.gtEq(2)) + + /** + * Minimum number of instances each child must have after split. + * If a split causes the left or right child to have fewer than minInstancesPerNode, + * the split will be discarded as invalid. + * Should be >= 1. + * (default = 1) + * @group param + */ + final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" + + " number of instances each child must have after split. If a split causes the left or right" + + " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." + + " Should be >= 1.", ParamValidators.gtEq(1)) + + /** + * Minimum information gain for a split to be considered at a tree node. + * (default = 0.0) + * @group param + */ + final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", + "Minimum information gain for a split to be considered at a tree node.") + + /** + * Maximum memory in MB allocated to histogram aggregation. + * (default = 256 MB) + * @group expertParam + */ + final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB", + "Maximum memory in MB allocated to histogram aggregation.", + ParamValidators.gtEq(0)) + + /** + * If false, the algorithm will pass trees to executors to match instances with nodes. + * If true, the algorithm will cache node IDs for each instance. + * Caching can speed up training of deeper trees. + * (default = false) + * @group expertParam + */ + final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" + + " algorithm will pass trees to executors to match instances with nodes. If true, the" + + " algorithm will cache node IDs for each instance. 
Caching can speed up training of deeper" + + " trees.") + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be >= 1. + * (default = 10) + * @group expertParam + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" + + " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" + + " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" + + " checkpoint directory is set in the SparkContext. Must be >= 1.", + ParamValidators.gtEq(1)) + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10) + + /** @group setParam */ + def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group getParam */ + final def getMaxDepth: Int = $(maxDepth) + + /** @group setParam */ + def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group getParam */ + final def getMaxBins: Int = $(maxBins) + + /** @group setParam */ + def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group getParam */ + final def getMinInstancesPerNode: Int = $(minInstancesPerNode) + + /** @group setParam */ + def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group getParam */ + final def getMinInfoGain: Double = $(minInfoGain) + + /** @group expertSetParam */ + def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertGetParam */ + final def getMaxMemoryInMB: Int = $(maxMemoryInMB) + + /** @group expertSetParam */ + def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** @group expertGetParam */ + final def getCacheNodeIds: 
Boolean = $(cacheNodeIds) + + /** @group expertSetParam */ + def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group expertGetParam */ + final def getCheckpointInterval: Int = $(checkpointInterval) + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OldStrategy = { + val strategy = OldStrategy.defaultStategy(oldAlgo) + strategy.impurity = oldImpurity + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = subsamplingRate + strategy + } +} + +/** + * Parameters for Decision Tree-based classification algorithms. + */ +private[ml] trait TreeClassifierParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "entropy" and "gini". + * (default = gini) + * @group param + */ + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase)) + + setDefault(impurity -> "gini") + + /** @group setParam */ + def setImpurity(value: String): this.type = set(impurity, value) + + /** @group getParam */ + final def getImpurity: String = $(impurity).toLowerCase + + /** Convert new impurity to old impurity. 
*/ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "entropy" => OldEntropy + case "gini" => OldGini + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeClassifierParams was given unrecognized impurity: $impurity.") + } + } +} + +private[ml] object TreeClassifierParams { + // These options should be lowercase. + final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase) +} + +/** + * Parameters for Decision Tree-based regression algorithms. + */ +private[ml] trait TreeRegressorParams extends Params { + + /** + * Criterion used for information gain calculation (case-insensitive). + * Supported: "variance". + * (default = variance) + * @group param + */ + final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" + + " information gain calculation (case-insensitive). Supported options:" + + s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}", + (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase)) + + setDefault(impurity -> "variance") + + /** @group setParam */ + def setImpurity(value: String): this.type = set(impurity, value) + + /** @group getParam */ + final def getImpurity: String = $(impurity).toLowerCase + + /** Convert new impurity to old impurity. */ + private[ml] def getOldImpurity: OldImpurity = { + getImpurity match { + case "variance" => OldVariance + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException( + s"TreeRegressorParams was given unrecognized impurity: $impurity") + } + } +} + +private[ml] object TreeRegressorParams { + // These options should be lowercase. + final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Decision Tree-based ensemble algorithms. 
+ * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { + + /** + * Fraction of the training data used for learning each decision tree, in range (0, 1]. + * (default = 1.0) + * @group param + */ + final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", + "Fraction of the training data used for learning each decision tree, in range (0, 1].", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + setDefault(subsamplingRate -> 1.0) + + /** @group setParam */ + def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group getParam */ + final def getSubsamplingRate: Double = $(subsamplingRate) + + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + + /** + * Create a Strategy instance to use with the old API. + * NOTE: The caller should set impurity and seed. + */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +/** + * :: DeveloperApi :: + * Parameters for Random Forest algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + /** + * The number of features to consider for splits at each tree node. 
+ * Supported options: + * - "auto": Choose automatically for task: + * If numTrees == 1, set to "all." + * If numTrees > 1 (forest), set to "sqrt" for classification and + * to "onethird" for regression. + * - "all": use all features + * - "onethird": use 1/3 of the features + * - "sqrt": use sqrt(number of features) + * - "log2": use log2(number of features) + * (default = "auto") + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @group param + */ + final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy", + "The number of features to consider for splits at each tree node." + + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", + (value: String) => + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + + setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + + /** @group setParam */ + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) + + /** @group setParam */ + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + + /** @group getParam */ + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase +} + +private[ml] object RandomForestParams { + // These options should be lowercase. 
+ final val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) +} + +/** + * :: DeveloperApi :: + * Parameters for Gradient-Boosted Tree algorithms. + * + * Note: Marked as private and DeveloperApi since this may be made public in the future. + */ +@DeveloperApi +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { + + /** + * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each + * estimator. + * (default = 0.1) + * @group param + */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." + + " learning rate) in interval (0, 1] for shrinking the contribution of each estimator", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + /* TODO: Add this doc when we add this param. SPARK-7132 + * Threshold for stopping early when runWithValidation is used. + * If the error rate on the validation input changes by less than the validationTol, + * then learning will stop early (before [[numIterations]]). + * This parameter is ignored when run is used. + * (default = 1e-5) + * @group param + */ + // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") + // validationTol -> 1e-5 + + setDefault(maxIter -> 20, stepSize -> 0.1) + + /** @group setParam */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + def setStepSize(value: Double): this.type = set(stepSize, value) + + /** @group getParam */ + final def getStepSize: Double = $(stepSize) + + /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */ + private[ml] def getOldBoostingStrategy( + categoricalFeatures: Map[Int, Int], + oldAlgo: OldAlgo.Algo): OldBoostingStrategy = { + val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance) + // NOTE: The old API does not support "seed" so we ignore it. 
+ new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize) + } + + /** Get old Gradient Boosting Loss type */ + private[ml] def getOldLossType: OldLoss +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
