Repository: spark Updated Branches: refs/heads/master ba5f81859 -> 93ef9b6a2
[SPARK-9622][ML] DecisionTreeRegressor: provide variance of prediction DecisionTreeRegressor will provide variance of prediction as a Double column. Author: Yanbo Liang <yblia...@gmail.com> Closes #8866 from yanboliang/spark-9622. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93ef9b6a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93ef9b6a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93ef9b6a Branch: refs/heads/master Commit: 93ef9b6a2aa1830170cb101f191022f2dda62c41 Parents: ba5f818 Author: Yanbo Liang <yblia...@gmail.com> Authored: Mon Jan 4 13:32:14 2016 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Mon Jan 4 13:32:14 2016 -0800 ---------------------------------------------------------------------- .../ml/param/shared/SharedParamsCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 15 ++++++++ .../ml/regression/DecisionTreeRegressor.scala | 36 ++++++++++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 18 ++++++++++ .../regression/DecisionTreeRegressorSuite.scala | 26 +++++++++++++- 5 files changed, 92 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c7bca12..4aff749 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), + ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cb2a060..c088c16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -139,6 +139,21 @@ private[ml] trait HasProbabilityCol extends Params { } /** + * Trait for shared param varianceCol. + */ +private[ml] trait HasVarianceCol extends Params { + + /** + * Param for Column name for the biased sample variance of prediction. + * @group param + */ + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction") + + /** @group getParam */ + final def getVarianceCol: String = $(varianceCol) +} + +/** * Trait for shared param threshold (default: 0.5). */ private[ml] trait HasThreshold extends Params { http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 477030d..18c94f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ /** * :: Experimental :: @@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams { + with DecisionTreeRegressorParams { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -73,6 +74,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -113,7 +117,10 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -129,6 +136,29 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + /** We need to update this function if we ever add other impurity measures. */ + protected def predictVariance(features: Vector): Double = { + rootNode.predictImpl(features).impurityStats.calculate() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } + output + } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 1da97db..7443097 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -256,6 +258,22 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams + with TreeRegressorParams with HasVarianceCol { + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType) + } else { + newSchema + } + } +} + /** * Parameters for Decision Tree-based ensemble algorithms. * http://git-wip-us.apache.org/repos/asf/spark/blob/93ef9b6a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 6999a91..0b39af5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{Row, DataFrame} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,6 +74,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setPredictionCol("") + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org