Repository: spark Updated Branches: refs/heads/branch-2.0 2d6f3bb4d -> d305f7227
[SPARK-15096][ML] LogisticRegression MultiClassSummarizer numClasses can fail if no valid labels are found ## What changes were proposed in this pull request? When no valid labels have been seen, `MultiClassSummarizer.numClasses` now returns 0. Previously it called `max` on an empty key set and failed with an `UnsupportedOperationException` ("empty.max"). ## How was this patch tested? Added a new unit test that checks `histogram` is empty and `numClasses` is 0 for a freshly constructed (empty) `MultiClassSummarizer`. Author: [email protected] <[email protected]> Closes #12969 from wangmiao1981/logisticR. (cherry picked from commit 354f8f11bd4b20fa99bd67a98da3525fd3d75c81) Signed-off-by: Sean Owen <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d305f722 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d305f722 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d305f722 Branch: refs/heads/branch-2.0 Commit: d305f72275255f8d21ebbe62b545ac663d617f3b Parents: 2d6f3bb Author: [email protected] <[email protected]> Authored: Sat May 14 09:45:56 2016 +0100 Committer: Sean Owen <[email protected]> Committed: Sat May 14 09:46:03 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/LogisticRegression.scala | 2 +- .../apache/spark/ml/classification/LogisticRegressionSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d305f722/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index d2d4e24..62d6897 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -744,7 +744,7 @@ private[classification] class MultiClassSummarizer extends Serializable { def countInvalid: Long = totalInvalidCnt /** @return The number of distinct labels in the input dataset. */ - def numClasses: Int = distinctMap.keySet.max + 1 + def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1 /** @return The weightSum of each label in the input dataset. */ def histogram: Array[Double] = { http://git-wip-us.apache.org/repos/asf/spark/blob/d305f722/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f127aa2..69650eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -256,6 +256,10 @@ class LogisticRegressionSuite assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) + val summarizer5 = new MultiClassSummarizer + assert(summarizer5.histogram.isEmpty) + assert(summarizer5.numClasses === 0) + // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
