This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 6834f46 [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary 6834f46 is described below commit 6834f4691b3e2603d410bfe24f0db0b6e3a36446 Author: Huaxin Gao <huax...@us.ibm.com> AuthorDate: Thu May 14 10:54:35 2020 -0500 [SPARK-31681][ML][PYSPARK] Python multiclass logistic regression evaluate should return LogisticRegressionSummary ### What changes were proposed in this pull request? Return LogisticRegressionSummary for multiclass logistic regression evaluate in PySpark. ### Why are the changes needed? Currently we have: ``` @since("2.0.0") def evaluate(self, dataset): if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) return BinaryLogisticRegressionSummary(java_blr_summary) ``` This always returns BinaryLogisticRegressionSummary; we should return LogisticRegressionSummary for multiclass logistic regression. ### Does this PR introduce _any_ user-facing change? Yes: evaluate returns LogisticRegressionSummary instead of BinaryLogisticRegressionSummary for multiclass logistic regression in Python. ### How was this patch tested? Unit tests. Closes #28503 from huaxingao/lr_summary. 
Authored-by: Huaxin Gao <huax...@us.ibm.com> Signed-off-by: Sean Owen <sro...@gmail.com> (cherry picked from commit e10516ae63cfc58f2d493e4d3f19940d45c8f033) Signed-off-by: Sean Owen <sro...@gmail.com> --- python/pyspark/ml/classification.py | 5 ++++- python/pyspark/ml/tests/test_training_summary.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 1436b78..424e16a 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -831,7 +831,10 @@ class LogisticRegressionModel(JavaProbabilisticClassificationModel, _LogisticReg if not isinstance(dataset, DataFrame): raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) java_blr_summary = self._call_java("evaluate", dataset) - return BinaryLogisticRegressionSummary(java_blr_summary) + if self.numClasses <= 2: + return BinaryLogisticRegressionSummary(java_blr_summary) + else: + return LogisticRegressionSummary(java_blr_summary) class LogisticRegressionSummary(JavaWrapper): diff --git a/python/pyspark/ml/tests/test_training_summary.py b/python/pyspark/ml/tests/test_training_summary.py index 1d19ebf..b505409 100644 --- a/python/pyspark/ml/tests/test_training_summary.py +++ b/python/pyspark/ml/tests/test_training_summary.py @@ -21,7 +21,8 @@ import unittest if sys.version > '3': basestring = str -from pyspark.ml.classification import LogisticRegression +from pyspark.ml.classification import BinaryLogisticRegressionSummary, LogisticRegression, \ + LogisticRegressionSummary from pyspark.ml.clustering import BisectingKMeans, GaussianMixture, KMeans from pyspark.ml.linalg import Vectors from pyspark.ml.regression import GeneralizedLinearRegression, LinearRegression @@ -149,6 +150,7 @@ class TrainingSummaryTest(SparkSessionTestCase): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala 
version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) def test_multiclass_logistic_regression_summary(self): @@ -187,6 +189,8 @@ class TrainingSummaryTest(SparkSessionTestCase): # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) + self.assertTrue(isinstance(sameSummary, LogisticRegressionSummary)) + self.assertFalse(isinstance(sameSummary, BinaryLogisticRegressionSummary)) self.assertAlmostEqual(sameSummary.accuracy, s.accuracy) def test_gaussian_mixture_summary(self): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org