Repository: spark Updated Branches: refs/heads/master cba69aeb4 -> 96028e36b
[SPARK-17139][ML][FOLLOW-UP] Add convenient method `asBinary` for casting to BinaryLogisticRegressionSummary ## What changes were proposed in this pull request? add an "asBinary" method to LogisticRegressionSummary for convenient casting to BinaryLogisticRegressionSummary. ## How was this patch tested? Testcase updated. Author: WeichenXu <weichen...@databricks.com> Closes #19072 from WeichenXu123/mlor_summary_as_binary. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/96028e36 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/96028e36 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/96028e36 Branch: refs/heads/master Commit: 96028e36b4d08427fdd94df55595849c2346ead4 Parents: cba69ae Author: WeichenXu <weichen...@databricks.com> Authored: Thu Aug 31 16:22:40 2017 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Thu Aug 31 16:22:40 2017 -0700 ---------------------------------------------------------------------- .../spark/ml/classification/LogisticRegression.scala | 11 +++++++++++ .../ml/classification/LogisticRegressionSuite.scala | 6 ++++++ project/MimaExcludes.scala | 1 + 3 files changed, 18 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/96028e36/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 1869d51..f491a67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1473,6 +1473,17 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Returns weighted 
averaged f1-measure. */ @Since("2.3.0") def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0) + + /** + * Convenient method for casting to binary logistic regression summary. + * This method throws a RuntimeException if the summary is not a binary summary. + */ + @Since("2.3.0") + def asBinary: BinaryLogisticRegressionSummary = this match { + case b: BinaryLogisticRegressionSummary => b + case _ => + throw new RuntimeException("Cannot cast to a binary summary.") + } } /** http://git-wip-us.apache.org/repos/asf/spark/blob/96028e36/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 6649fa4..6bf1253 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -256,6 +256,7 @@ class LogisticRegressionSuite val blorModel = lr.fit(smallBinaryDataset) assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) + assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary]) assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset) @@ -265,6 +266,11 @@ class LogisticRegressionSuite mlorModel.binarySummary } } + withClue("cannot cast summary to binary summary multiclass model") { + intercept[RuntimeException] { + mlorModel.summary.asBinary + } + } val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset) assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary]) http://git-wip-us.apache.org/repos/asf/spark/blob/96028e36/project/MimaExcludes.scala
---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index eecda26..27e4183 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -62,6 +62,7 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_=") ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org