Repository: spark Updated Branches: refs/heads/master 85200c09a -> b4574e387
[SPARK-12908][ML] Add warning message for LogisticRegression for potential convergence issue. When all labels are the same, LogisticRegression without an intercept may fail to converge. GLMNET doesn't support this case and will just exit. GLM can train, but emits a warning message saying the algorithm doesn't converge. Author: DB Tsai <d...@netflix.com> Closes #10862 from dbtsai/add-tests. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4574e38 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4574e38 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4574e38 Branch: refs/heads/master Commit: b4574e387d0124667bdbb35f8c7c3e2065b14ba9 Parents: 85200c0 Author: DB Tsai <d...@netflix.com> Authored: Thu Jan 21 17:24:48 2016 -0800 Committer: DB Tsai <d...@netflix.com> Committed: Thu Jan 21 17:24:48 2016 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/classification/LogisticRegression.scala | 8 ++++++++ 1 file changed, 8 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b4574e38/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index dad8dfc..c98a78a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -300,6 +300,14 @@ class LogisticRegression @Since("1.2.0") ( s"training is not needed.") (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double]) } else { + if (!$(fitIntercept) && numClasses == 2 && histogram(0) == 
0.0) { + logWarning(s"All labels are one and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } else if (!$(fitIntercept) && numClasses == 1) { + logWarning(s"All labels are zero and fitIntercept=false. It's a dangerous ground, " + + s"so the algorithm may not converge.") + } + val featuresMean = summarizer.mean.toArray val featuresStd = summarizer.variance.toArray.map(math.sqrt) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org