Repository: spark
Updated Branches:
  refs/heads/branch-2.2 2b4bd7910 -> 0d4ef2f69


[SPARK-21818][ML][MLLIB] Fix bug where MultivariateOnlineSummarizer.variance can generate a negative result

Because of numerical error, MultivariateOnlineSummarizer.variance can produce a negative variance.

**This is a serious bug: many algorithms in MLlib compute the standard deviation as `sqrt(variance)`, so a negative variance yields NaN and crashes the whole algorithm.**

We can reproduce this bug with the following code:
```
    val summarizer1 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.7)
    val summarizer2 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)
    val summarizer3 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.5)
    val summarizer4 = (new MultivariateOnlineSummarizer)
      .add(Vectors.dense(3.0), 0.4)

    val summarizer = summarizer1
      .merge(summarizer2)
      .merge(summarizer3)
      .merge(summarizer4)

    println(summarizer.variance(0))
```
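
For intuition, here is a minimal standalone sketch (not Spark code; the names are illustrative) of the naive weighted-variance formula `E[x^2] - E[x]^2` over the same points and weights as the repro above. With weights that are not exactly representable in binary floating point, a mathematically zero variance can come out as a tiny negative number, and `sqrt` of that is NaN:
```
    // Illustrative only: naive weighted variance E[x^2] - E[x]^2.
    val points = Seq((3.0, 0.7), (3.0, 0.4), (3.0, 0.5), (3.0, 0.4))
    val wSum  = points.map(_._2).sum
    val bSum  = points.map { case (x, w) => x * w }.sum
    val bbSum = points.map { case (x, w) => x * x * w }.sum
    val bBar  = bSum / wSum
    // Mathematically zero, but rounding in the sums and divisions can leave
    // a tiny residue that may be negative.
    val naiveVariance = bbSum / wSum - bBar * bBar
    println(naiveVariance)
    // If the residue is negative, the square root becomes NaN.
    println(math.sqrt(naiveVariance))
```
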
This PR fixes the bugs in `mllib.stat.MultivariateOnlineSummarizer.variance` and 
`ml.stat.SummarizerBuffer.variance`, as well as several places in 
`WeightedLeastSquares`.
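
The guard applied throughout the patch (see the diff below) is to clamp the computed variance at zero before taking a square root; a minimal sketch of the pattern (the `safeStd` helper is illustrative, not part of the patch):
```
    // Illustrative helper: clamp a numerically computed variance at zero before
    // taking the square root, so tiny negative residues yield 0.0 instead of NaN.
    def safeStd(rawVariance: Double): Double = math.sqrt(math.max(rawVariance, 0.0))

    println(safeStd(-1.0e-16))  // prints 0.0; without the clamp, sqrt would give NaN
```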

Test cases added.

Author: WeichenXu <weichenxu...@outlook.com>

Closes #19029 from WeichenXu123/fix_summarizer_var_bug.

(cherry picked from commit 0456b4050817e64f27824720e695bbfff738d474)
Signed-off-by: Sean Owen <so...@cloudera.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0d4ef2f6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0d4ef2f6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0d4ef2f6

Branch: refs/heads/branch-2.2
Commit: 0d4ef2f690e378cade0a3ec84d535a535dc20dfc
Parents: 2b4bd79
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Mon Aug 28 07:41:42 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Mon Aug 28 08:00:29 2017 +0100

----------------------------------------------------------------------
 .../spark/ml/optim/WeightedLeastSquares.scala     | 12 +++++++++---
 .../mllib/stat/MultivariateOnlineSummarizer.scala |  5 +++--
 .../stat/MultivariateOnlineSummarizerSuite.scala  | 18 ++++++++++++++++++
 3 files changed, 30 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 56ab967..c5c9c8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
     /**
      * Weighted population standard deviation of labels.
      */
-    def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+    def bStd: Double = {
+      // We prevent variance from negative value caused by numerical error.
+      val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+      math.sqrt(variance)
+    }
 
     /**
      * Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+        // We prevent variance from negative value caused by numerical error.
+        std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
         i += j
         j += 1
       }
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
       while (i < triK) {
         val l = j - 2
         val aw = aSum(l) / wSum
-        variance(l) = aaValues(i) / wSum - aw * aw
+        // We prevent variance from negative value caused by numerical error.
+        variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
         i += j
         j += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7dc0c45..8121880 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
       var i = 0
       val len = currM2n.length
       while (i < len) {
-        realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
-          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+        // We prevent variance from negative value caused by numerical error.
+        realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+          (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
         i += 1
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0d4ef2f6/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 797e84f..c6466bc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
     assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
     assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
   }
+
+  test ("test zero variance (SPARK-21818)") {
+    val summarizer1 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.7)
+    val summarizer2 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+    val summarizer3 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.5)
+    val summarizer4 = (new MultivariateOnlineSummarizer)
+      .add(Vectors.dense(3.0), 0.4)
+
+    val summarizer = summarizer1
+      .merge(summarizer2)
+      .merge(summarizer3)
+      .merge(summarizer4)
+
+    assert(summarizer.variance(0) >= 0.0)
+  }
 }

