Repository: spark Updated Branches: refs/heads/branch-2.2 2b4bd7910 -> 0d4ef2f69
[SPARK-21818][ML][MLLIB] Fix bug where MultivariateOnlineSummarizer.variance generates negative results. Because of numerical error, MultivariateOnlineSummarizer.variance can produce a negative variance. **This is a serious bug because many algorithms in MLlib use the stddev computed as** `sqrt(variance)`**, so a negative variance produces NaN and crashes the whole algorithm.** The bug can be reproduced with the following code: ``` val summarizer1 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.7) val summarizer2 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer3 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.5) val summarizer4 = (new MultivariateOnlineSummarizer) .add(Vectors.dense(3.0), 0.4) val summarizer = summarizer1 .merge(summarizer2) .merge(summarizer3) .merge(summarizer4) println(summarizer.variance(0)) ``` This PR fixes the bug in `mllib.stat.MultivariateOnlineSummarizer.variance` and `ml.stat.SummarizerBuffer.variance`, and in several places in `WeightedLeastSquares`, by clamping the computed variance to be non-negative. Test cases added. (This branch-2.2 cherry-pick changes only the three files listed below; the `ml.stat.SummarizerBuffer` change is not included.) Author: WeichenXu <weichenxu...@outlook.com> Closes #19029 from WeichenXu123/fix_summarizer_var_bug. 
(cherry picked from commit 0456b4050817e64f27824720e695bbfff738d474) Signed-off-by: Sean Owen <so...@cloudera.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0d4ef2f6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0d4ef2f6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0d4ef2f6 Branch: refs/heads/branch-2.2 Commit: 0d4ef2f690e378cade0a3ec84d535a535dc20dfc Parents: 2b4bd79 Author: WeichenXu <weichenxu...@outlook.com> Authored: Mon Aug 28 07:41:42 2017 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Mon Aug 28 08:00:29 2017 +0100 ---------------------------------------------------------------------- .../spark/ml/optim/WeightedLeastSquares.scala | 12 +++++++++--- .../mllib/stat/MultivariateOnlineSummarizer.scala | 5 +++-- .../stat/MultivariateOnlineSummarizerSuite.scala | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 56ab967..c5c9c8e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares { /** * Weighted population standard deviation of labels. */ - def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar) + def bStd: Double = { + // We prevent variance from negative value caused by numerical error. + val variance = math.max(bbSum / wSum - bBar * bBar, 0.0) + math.sqrt(variance) + } /** * Weighted mean of (label * features). 
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares { while (i < triK) { val l = j - 2 val aw = aSum(l) / wSum - std(l) = math.sqrt(aaValues(i) / wSum - aw * aw) + // We prevent variance from negative value caused by numerical error. + std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0)) i += j j += 1 } @@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares { while (i < triK) { val l = j - 2 val aw = aSum(l) / wSum - variance(l) = aaValues(i) / wSum - aw * aw + // We prevent variance from negative value caused by numerical error. + variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0) i += j j += 1 } http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7dc0c45..8121880 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S var i = 0 val len = currM2n.length while (i < len) { - realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * - (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator + // We prevent variance from negative value caused by numerical error. 
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) * + (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0) i += 1 } } http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 797e84f..c6466bc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite { assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14) assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14) } + + test ("test zero variance (SPARK-21818)") { + val summarizer1 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.7) + val summarizer2 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.4) + val summarizer3 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.5) + val summarizer4 = (new MultivariateOnlineSummarizer) + .add(Vectors.dense(3.0), 0.4) + + val summarizer = summarizer1 + .merge(summarizer2) + .merge(summarizer3) + .merge(summarizer4) + + assert(summarizer.variance(0) >= 0.0) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org