Repository: spark Updated Branches: refs/heads/master b1b4ee7f3 -> 518ab5101
[SPARK-10991][ML] logistic regression training summary handle empty prediction col LogisticRegression training summary should still function if the predictionCol is set to an empty string or otherwise unset (related to https://issues.apache.org/jira/browse/SPARK-9718 ) Author: Holden Karau <hol...@pigscanfly.ca> Author: Holden Karau <hol...@us.ibm.com> Closes #9037 from holdenk/SPARK-10991-LogisticRegressionTrainingSummary-handle-empty-prediction-col. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/518ab510 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/518ab510 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/518ab510 Branch: refs/heads/master Commit: 518ab5101073ee35d62e33c8f7281a1e6342101e Parents: b1b4ee7 Author: Holden Karau <hol...@pigscanfly.ca> Authored: Fri Dec 11 02:35:53 2015 -0500 Committer: DB Tsai <d...@netflix.com> Committed: Fri Dec 11 02:35:53 2015 -0500 ---------------------------------------------------------------------- .../ml/classification/LogisticRegression.scala | 20 ++++++++++++++++++-- .../LogisticRegressionSuite.scala | 11 +++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 19cc323..486043e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -389,9 +389,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.unpersist() val model 
= copyValues(new LogisticRegressionModel(uid, coefficients, intercept)) + val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol() val logRegSummary = new BinaryLogisticRegressionTrainingSummary( - model.transform(dataset), - $(probabilityCol), + summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol), objectiveHistory) @@ -469,6 +470,21 @@ class LogisticRegressionModel private[ml] ( new NullPointerException()) } + /** + * If the probability column is set returns the current model and probability column, + * otherwise generates a new column and sets it as the probability column on a new copy + * of the current model. + */ + private[classification] def findSummaryModelAndProbabilityCol(): + (LogisticRegressionModel, String) = { + $(probabilityCol) match { + case "" => + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) + case p => (this, p) + } + } + private[classification] def setSummary( summary: LogisticRegressionTrainingSummary): this.type = { this.trainingSummary = Some(summary) http://git-wip-us.apache.org/repos/asf/spark/blob/518ab510/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index a9a6ff8..1087afb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -99,6 +99,17 @@ class LogisticRegressionSuite assert(model.hasParent) } + test("empty probabilityCol") { + val lr = new LogisticRegression().setProbabilityCol("") + val model = lr.fit(dataset) + 
assert(model.hasSummary) + // Validate that we re-insert a probability column for evaluation + val fieldNames = model.summary.predictions.schema.fieldNames + assert((dataset.schema.fieldNames.toSet).subsetOf( + fieldNames.toSet)) + assert(fieldNames.exists(s => s.startsWith("probability_"))) + } + test("setThreshold, getThreshold") { val lr = new LogisticRegression // default --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org