Repository: spark Updated Branches: refs/heads/master 0f1f31d3a -> 354f8f11b
[SPARK-15096][ML] LogisticRegression MultiClassSummarizer numClasses can fail if no valid labels are found ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) Make numClasses return 0 when no valid labels have been seen (distinctMap is empty), instead of failing with the exception thrown by calling max on an empty key set. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Add a new unit test, which calls histogram and numClasses on a MultiClassSummarizer that has seen no labels and checks that the histogram is empty and numClasses is 0. Author: [email protected] <[email protected]> Closes #12969 from wangmiao1981/logisticR. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/354f8f11 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/354f8f11 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/354f8f11 Branch: refs/heads/master Commit: 354f8f11bd4b20fa99bd67a98da3525fd3d75c81 Parents: 0f1f31d Author: [email protected] <[email protected]> Authored: Sat May 14 09:45:56 2016 +0100 Committer: Sean Owen <[email protected]> Committed: Sat May 14 09:45:56 2016 +0100 ---------------------------------------------------------------------- .../org/apache/spark/ml/classification/LogisticRegression.scala | 2 +- .../apache/spark/ml/classification/LogisticRegressionSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/354f8f11/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3e8040d..ffd03e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -745,7 +745,7 @@ private[classification] class MultiClassSummarizer extends Serializable { def countInvalid: Long = totalInvalidCnt /** @return The number of distinct labels in the input dataset. */ - def numClasses: Int = distinctMap.keySet.max + 1 + def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1 /** @return The weightSum of each label in the input dataset. */ def histogram: Array[Double] = { http://git-wip-us.apache.org/repos/asf/spark/blob/354f8f11/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f127aa2..69650eb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -256,6 +256,10 @@ class LogisticRegressionSuite assert(summarizer4.countInvalid === 2) assert(summarizer4.numClasses === 4) + val summarizer5 = new MultiClassSummarizer + assert(summarizer5.histogram.isEmpty) + assert(summarizer5.numClasses === 0) + // small map merges large one val summarizerA = summarizer1.merge(summarizer2) assert(summarizerA.hashCode() === summarizer2.hashCode()) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
