Repository: spark
Updated Branches:
  refs/heads/master 9bb35c5b5 -> 9753835cf


[SPARK-12230][ML] WeightedLeastSquares.fit() should handle division by zero properly if standard deviation of target variable is zero.

This fixes the behavior of WeightedLeastSquares.fit() when the standard
deviation of the target variable is zero. If fitIntercept is true, there is
no need to train: the coefficients are all zero and the intercept is the mean
of the label.

Author: Imran Younus <iyou...@us.ibm.com>

Closes #10274 from iyounus/SPARK-12230_bug_fix_in_weighted_least_squares.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9753835c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9753835c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9753835c

Branch: refs/heads/master
Commit: 9753835cf3acc135e61bf668223046e29306c80d
Parents: 9bb35c5
Author: Imran Younus <iyou...@us.ibm.com>
Authored: Wed Jan 20 11:16:59 2016 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed Jan 20 11:16:59 2016 -0800

----------------------------------------------------------------------
 .../spark/ml/optim/WeightedLeastSquares.scala   | 21 +++++-
 .../ml/optim/WeightedLeastSquaresSuite.scala    | 69 ++++++++++++++++++--
 2 files changed, 83 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9753835c/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 8617722..797870e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -86,6 +86,24 @@ private[ml] class WeightedLeastSquares(
     val aaBar = summary.aaBar
     val aaValues = aaBar.values
 
+    if (bStd == 0) {
+      if (fitIntercept) {
+        logWarning(s"The standard deviation of the label is zero, so the 
coefficients will be " +
+          s"zeros and the intercept will be the mean of the label; as a 
result, " +
+          s"training is not needed.")
+        val coefficients = new DenseVector(Array.ofDim(k-1))
+        val intercept = bBar
+        val diagInvAtWA = new DenseVector(Array(0D))
+        return new WeightedLeastSquaresModel(coefficients, intercept, 
diagInvAtWA)
+      } else {
+        require(!(regParam > 0.0 && standardizeLabel),
+          "The standard deviation of the label is zero. " +
+            "Model cannot be regularized with standardization=true")
+        logWarning(s"The standard deviation of the label is zero. " +
+          "Consider setting fitIntercept=true.")
+      }
+    }
+
     // add regularization to diagonals
     var i = 0
     var j = 2
@@ -94,8 +112,7 @@ private[ml] class WeightedLeastSquares(
       if (standardizeFeatures) {
         lambda *= aVar(j - 2)
       }
-      if (standardizeLabel) {
-        // TODO: handle the case when bStd = 0
+      if (standardizeLabel && bStd != 0) {
         lambda /= bStd
       }
       aaValues(i) += lambda

http://git-wip-us.apache.org/repos/asf/spark/blob/9753835c/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index b542ba3..0b58a98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
 class WeightedLeastSquaresSuite extends SparkFunSuite with 
MLlibTestSparkContext {
 
   private var instances: RDD[Instance] = _
+  private var instancesConstLabel: RDD[Instance] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -43,6 +44,20 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
       Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
       Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
     ), 2)
+
+    /*
+       R code:
+
+       A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2)
+       b.const <- c(17, 17, 17, 17)
+       w <- c(1, 2, 3, 4)
+     */
+    instancesConstLabel = sc.parallelize(Seq(
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
+      Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
+      Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
+    ), 2)
   }
 
   test("WLS against lm") {
@@ -65,15 +80,59 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
 
     var idx = 0
     for (fitIntercept <- Seq(false, true)) {
-      val wls = new WeightedLeastSquares(
-        fitIntercept, regParam = 0.0, standardizeFeatures = false, 
standardizeLabel = false)
-        .fit(instances)
-      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
-      assert(actual ~== expected(idx) absTol 1e-4)
+       for (standardization <- Seq(false, true)) {
+         val wls = new WeightedLeastSquares(
+           fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+           standardizeLabel = standardization).fit(instances)
+         val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
+         assert(actual ~== expected(idx) absTol 1e-4)
+       }
+      idx += 1
+    }
+  }
+
+  test("WLS against lm when label is constant and no regularization") {
+    /*
+       R code:
+
+       df.const.label <- as.data.frame(cbind(A, b.const))
+       for (formula in c(b.const ~ . -1, b.const ~ .)) {
+         model <- lm(formula, data=df.const.label, weights=w)
+         print(as.vector(coef(model)))
+       }
+
+      [1] -9.221298  3.394343
+      [1] 17  0  0
+    */
+
+    val expected = Seq(
+      Vectors.dense(0.0, -9.221298, 3.394343),
+      Vectors.dense(17.0, 0.0, 0.0))
+
+    var idx = 0
+    for (fitIntercept <- Seq(false, true)) {
+      for (standardization <- Seq(false, true)) {
+        val wls = new WeightedLeastSquares(
+          fitIntercept, regParam = 0.0, standardizeFeatures = standardization,
+          standardizeLabel = standardization).fit(instancesConstLabel)
+        val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
+        assert(actual ~== expected(idx) absTol 1e-4)
+      }
       idx += 1
     }
   }
 
+  test("WLS with regularization when label is constant") {
+    // if regParam is non-zero and standardization is true, the problem is 
ill-defined and
+    // an exception is thrown.
+    val wls = new WeightedLeastSquares(
+      fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
+      standardizeLabel = true)
+    intercept[IllegalArgumentException]{
+      wls.fit(instancesConstLabel)
+    }
+  }
+
   test("WLS against glmnet") {
     /*
        R code:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to