Repository: spark Updated Branches: refs/heads/master 12c360c05 -> 5ee72454d
[SPARK-14852][ML] refactored GLM summary into training, non-training summaries ## What changes were proposed in this pull request? This splits GeneralizedLinearRegressionSummary into 2 summary types: * GeneralizedLinearRegressionSummary, which does not store info from fitting (diagInvAtWA) * GeneralizedLinearRegressionTrainingSummary, which is a subclass of GeneralizedLinearRegressionSummary and stores info from fitting. This also adds a method evaluate() which can produce a GeneralizedLinearRegressionSummary on a new dataset. The summary no longer provides the model itself as a public val. Also: * Fixes bug where GeneralizedLinearRegressionTrainingSummary was created with model, not summaryModel. * Adds hasSummary method. * Renames findSummaryModelAndPredictionCol -> getSummaryModel and simplifies that method. * In summary, extract values from model immediately in case user later changes those (e.g., predictionCol). * Pardon the style fixes; that is IntelliJ being obnoxious. ## How was this patch tested? Existing unit tests + updated test for evaluate and hasSummary Author: Joseph K. Bradley <[email protected]> Closes #12624 from jkbradley/model-summary-api. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ee72454 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ee72454 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ee72454 Branch: refs/heads/master Commit: 5ee72454df21ef4668c855134627d0cdf5d35132 Parents: 12c360c Author: Joseph K. Bradley <[email protected]> Authored: Thu Apr 28 11:22:13 2016 -0700 Committer: Joseph K.
Bradley <[email protected]> Committed: Thu Apr 28 11:22:13 2016 -0700 ---------------------------------------------------------------------- .../GeneralizedLinearRegression.scala | 156 ++++++++++++------- .../GeneralizedLinearRegressionSuite.scala | 14 ++ 2 files changed, 115 insertions(+), 55 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5ee72454/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index dcf69af..bf9d3ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + /** * Params for Generalized Linear Regression. */ @@ -81,6 +82,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for link prediction (linear predictor) column name. * Default is empty, which means we do not output link prediction. + * * @group param */ @Since("2.0.0") @@ -144,6 +146,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the value of param [[family]]. * Default is "gaussian". + * * @group setParam */ @Since("2.0.0") @@ -152,6 +155,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the value of param [[link]]. 
+ * * @group setParam */ @Since("2.0.0") @@ -160,6 +164,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets if we should fit the intercept. * Default is true. + * * @group setParam */ @Since("2.0.0") @@ -168,6 +173,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the maximum number of iterations. * Default is 25 if the solver algorithm is "irls". + * * @group setParam */ @Since("2.0.0") @@ -177,6 +183,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val * Sets the convergence tolerance of iterations. * Smaller value will lead to higher accuracy with the cost of more iterations. * Default is 1E-6. + * * @group setParam */ @Since("2.0.0") @@ -190,6 +197,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val * 0.5 * regParam * L2norm(coefficients)^2 * }}} * Default is 0.0. + * * @group setParam */ @Since("2.0.0") @@ -200,6 +208,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val * Sets the value of param [[weightCol]]. * If this is not set or empty, we treat all instance weights as 1.0. * Default is empty, so all instances have weight one. + * * @group setParam */ @Since("2.0.0") @@ -209,6 +218,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the solver algorithm used for optimization. * Currently only support "irls" which is also the default solver. + * * @group setParam */ @Since("2.0.0") @@ -217,6 +227,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val /** * Sets the link prediction (linear predictor) column name. 
+ * * @group setParam */ @Since("2.0.0") @@ -256,15 +267,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new GeneralizedLinearRegressionSummary( - summaryModel.transform(dataset), - predictionColName, - model, - wlsModel.diagInvAtWA.toArray, - 1, - getSolver) + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + wlsModel.diagInvAtWA.toArray, 1, getSolver) return model.setSummary(trainingSummary) } @@ -277,16 +281,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - val trainingSummary = new GeneralizedLinearRegressionSummary( - summaryModel.transform(dataset), - predictionColName, - model, - irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations, - getSolver) - + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) model.setSummary(trainingSummary) } @@ -363,6 +359,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine /** * A description of the error distribution to be used in the model. + * * @param name the name of the family. 
*/ private[ml] abstract class Family(val name: String) extends Serializable { @@ -381,6 +378,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine /** * Akaike's 'An Information Criterion'(AIC) value of the family for a given dataset. + * * @param predictions an RDD of (y, mu, weight) of instances in evaluation dataset * @param deviance the deviance for the fitted model in evaluation dataset * @param numInstances number of instances in evaluation dataset @@ -400,6 +398,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine /** * Gets the [[Family]] object from its name. + * * @param name family name: "gaussian", "binomial", "poisson" or "gamma". */ def fromName(name: String): Family = { @@ -579,6 +578,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine * A description of the link function to be used in the model. * The link function provides the relationship between the linear predictor * and the mean of the distribution function. + * * @param name the name of link function. */ private[ml] abstract class Link(val name: String) extends Serializable { @@ -597,6 +597,7 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine /** * Gets the [[Link]] object from its name. + * * @param name link name: "identity", "logit", "log", * "inverse", "probit", "cloglog" or "sqrt". */ @@ -694,6 +695,7 @@ class GeneralizedLinearRegressionModel private[ml] ( /** * Sets the link prediction (linear predictor) column name. 
+ * * @group setParam */ @Since("2.0.0") @@ -736,39 +738,39 @@ class GeneralizedLinearRegressionModel private[ml] ( if ($(linkPredictionCol).nonEmpty) { output = output.withColumn($(linkPredictionCol), predictLinkUDF(col($(featuresCol)))) } - output.toDF + output.toDF() } - private var trainingSummary: Option[GeneralizedLinearRegressionSummary] = None + private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None /** * Gets R-like summary of model on training set. An exception is - * thrown if `trainingSummary == None`. + * thrown if there is no summary available. */ @Since("2.0.0") - def summary: GeneralizedLinearRegressionSummary = trainingSummary.getOrElse { + def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse { throw new SparkException( "No training summary available for this GeneralizedLinearRegressionModel") } - private[regression] def setSummary(summary: GeneralizedLinearRegressionSummary): this.type = { + /** + * Indicates if [[summary]] is available. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + private[regression] + def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) this } /** - * If the prediction column is set returns the current model and prediction column, - * otherwise generates a new column and sets it as the prediction column on a new copy - * of the current model. + * Evaluate the model on the given dataset, returning a summary of the results. 
*/ - private[regression] def findSummaryModelAndPredictionCol() - : (GeneralizedLinearRegressionModel, String) = { - $(predictionCol) match { - case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString - (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) - case p => (this, p) - } + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): GeneralizedLinearRegressionSummary = { + new GeneralizedLinearRegressionSummary(dataset, this) } @Since("2.0.0") @@ -834,36 +836,55 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr /** * :: Experimental :: - * Summarizing Generalized Linear regression Fits. + * Summary of [[GeneralizedLinearRegression]] model and predictions. * - * @param predictions predictions output by the model's `transform` method - * @param predictionCol field in "predictions" which gives the prediction value of each instance - * @param model the model that should be summarized - * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration - * @param numIterations number of iterations - * @param solver the solver algorithm used for model training + * @param dataset Dataset to be summarized. + * @param origModel Model to be summarized. This is copied to create an internal + * model which cannot be modified from outside. 
*/ @Since("2.0.0") @Experimental class GeneralizedLinearRegressionSummary private[regression] ( - @Since("2.0.0") @transient val predictions: DataFrame, - @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val model: GeneralizedLinearRegressionModel, - private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int, - @Since("2.0.0") val solver: String) extends Serializable { + dataset: Dataset[_], + origModel: GeneralizedLinearRegressionModel) extends Serializable { import GeneralizedLinearRegression._ - private lazy val family = Family.fromName(model.getFamily) - private lazy val link = if (model.isDefined(model.getParam("link"))) { + /** + * Field in "predictions" which gives the prediction value of each instance. + * This is set to a new column name if the original model's `predictionCol` is not set. + */ + @Since("2.0.0") + val predictionCol: String = { + if (origModel.isDefined(origModel.predictionCol) && origModel.getPredictionCol != "") { + origModel.getPredictionCol + } else { + "prediction_" + java.util.UUID.randomUUID.toString + } + } + + /** + * Private copy of model to ensure Params are not modified outside this class. + * Coefficients is not a deep copy, but that is acceptable. + * + * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set, + * and [[model]] must be set before [[predictions]] is set! 
+ */ + protected val model: GeneralizedLinearRegressionModel = + origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) + + /** predictions output by the model's `transform` method */ + @Since("2.0.0") @transient val predictions: DataFrame = model.transform(dataset) + + private[regression] lazy val family: Family = Family.fromName(model.getFamily) + private[regression] lazy val link: Link = if (model.isDefined(model.link)) { Link.fromName(model.getLink) } else { family.defaultLink } /** Number of instances in DataFrame predictions */ - private lazy val numInstances: Long = predictions.count() + private[regression] lazy val numInstances: Long = predictions.count() /** The numeric rank of the fitted linear model */ @Since("2.0.0") @@ -891,7 +912,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( numInstances } - private lazy val devianceResiduals: DataFrame = { + private[regression] lazy val devianceResiduals: DataFrame = { val drUDF = udf { (y: Double, mu: Double, weight: Double) => val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0)) if (y > mu) r else -1.0 * r @@ -901,19 +922,19 @@ class GeneralizedLinearRegressionSummary private[regression] ( drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals")) } - private lazy val pearsonResiduals: DataFrame = { + private[regression] lazy val pearsonResiduals: DataFrame = { val prUDF = udf { mu: Double => family.variance(mu) } val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol) predictions.select(col(model.getLabelCol).minus(col(predictionCol)) .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals")) } - private lazy val workingResiduals: DataFrame = { + private[regression] lazy val workingResiduals: DataFrame = { val wrUDF = udf { (y: Double, mu: Double) => (y - mu) * link.deriv(mu) } predictions.select(wrUDF(col(model.getLabelCol), col(predictionCol)).as("workingResiduals")) } - private lazy val responseResiduals: 
DataFrame = { + private[regression] lazy val responseResiduals: DataFrame = { predictions.select(col(model.getLabelCol).minus(col(predictionCol)).as("responseResiduals")) } @@ -925,6 +946,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * Get the residuals of the fitted model by type. + * * @param residualsType The type of residuals which should be returned. * Supported options: deviance, pearson, working and response. */ @@ -996,6 +1018,30 @@ class GeneralizedLinearRegressionSummary private[regression] ( } family.aic(t, deviance, numInstances, weightSum) + 2 * rank } +} + +/** + * :: Experimental :: + * Summary of [[GeneralizedLinearRegression]] fitting and model. + * + * @param dataset Dataset to be summarized. + * @param origModel Model to be summarized. This is copied to create an internal + * model which cannot be modified from outside. + * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration + * @param numIterations number of iterations + * @param solver the solver algorithm used for model training + */ +@Since("2.0.0") +@Experimental +class GeneralizedLinearRegressionTrainingSummary private[regression] ( + dataset: Dataset[_], + origModel: GeneralizedLinearRegressionModel, + private val diagInvAtWA: Array[Double], + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) + extends GeneralizedLinearRegressionSummary(dataset, origModel) with Serializable { + + import GeneralizedLinearRegression._ /** * Standard error of estimated coefficients and intercept. 
http://git-wip-us.apache.org/repos/asf/spark/blob/5ee72454/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 0b5e77a..e4c9a3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -603,7 +603,9 @@ class GeneralizedLinearRegressionSuite val residualDegreeOfFreedomR = 1 val aicR = 18.783 + assert(model.hasSummary) val summary = model.summary + assert(summary.isInstanceOf[GeneralizedLinearRegressionTrainingSummary]) val devianceResiduals = summary.residuals() .select(col("devianceResiduals")) @@ -643,6 +645,18 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) assert(summary.solver === "irls") + + val summary2: GeneralizedLinearRegressionSummary = model.evaluate(datasetWithWeight) + assert(summary.predictions.columns.toSet === summary2.predictions.columns.toSet) + assert(summary.predictionCol === summary2.predictionCol) + assert(summary.rank === summary2.rank) + assert(summary.degreesOfFreedom === summary2.degreesOfFreedom) + assert(summary.residualDegreeOfFreedom === summary2.residualDegreeOfFreedom) + assert(summary.residualDegreeOfFreedomNull === summary2.residualDegreeOfFreedomNull) + assert(summary.nullDeviance === summary2.nullDeviance) + assert(summary.deviance === summary2.deviance) + assert(summary.dispersion === summary2.dispersion) + assert(summary.aic === summary2.aic) } test("glm summary: binomial family with weight") { --------------------------------------------------------------------- To 
unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
