Github user yanboliang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15394#discussion_r82980460
  
    --- Diff: 
mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala 
---
    @@ -132,24 +197,234 @@ class WeightedLeastSquaresSuite extends 
SparkFunSuite with MLlibTestSparkContext
         var idx = 0
         for (fitIntercept <- Seq(false, true)) {
           for (standardization <- Seq(false, true)) {
    -        val wls = new WeightedLeastSquares(
    -          fitIntercept, regParam = 0.0, standardizeFeatures = 
standardization,
    -          standardizeLabel = standardization).fit(instancesConstLabel)
    -        val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
    -        assert(actual ~== expected(idx) absTol 1e-4)
    +        for (solver <- WeightedLeastSquares.supportedSolvers) {
    +          val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, 
elasticNetParam = 0.0,
    +            standardizeFeatures = standardization,
    +            standardizeLabel = standardization, solverType = 
solver).fit(instancesConstLabel)
    +          val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
    +          assert(actual ~== expected(idx) absTol 1e-4)
    +        }
           }
           idx += 1
         }
    +
    +    // when label is constant zero, and fitIntercept is false, we should 
not train and get all zeros
    +    val instancesConstZeroLabel = instancesConstLabel.map { case 
Instance(l, w, f) =>
    +      Instance(0.0, w, f)
    +    }
    +    for (solver <- WeightedLeastSquares.supportedSolvers) {
    +      val wls = new WeightedLeastSquares(false, 0.0, 0.0, true, true, 
solverType = solver)
    +        .fit(instancesConstZeroLabel)
    +      val actual = Vectors.dense(wls.intercept, wls.coefficients(0), 
wls.coefficients(1))
    +      assert(actual === Vectors.dense(0.0, 0.0, 0.0))
    +      assert(wls.objectiveHistory === Array(0.0))
    +    }
       }
     
       test("WLS with regularization when label is constant") {
         // if regParam is non-zero and standardization is true, the problem is 
ill-defined and
         // an exception is thrown.
    -    val wls = new WeightedLeastSquares(
    -      fitIntercept = false, regParam = 0.1, standardizeFeatures = true,
    -      standardizeLabel = true)
    -    intercept[IllegalArgumentException]{
    -      wls.fit(instancesConstLabel)
    +    for (solver <- WeightedLeastSquares.supportedSolvers) {
    +      val wls = new WeightedLeastSquares(
    +        fitIntercept = false, regParam = 0.1, elasticNetParam = 0.0, 
standardizeFeatures = true,
    +        standardizeLabel = true, solverType = solver)
    +      intercept[IllegalArgumentException]{
    +        wls.fit(instancesConstLabel)
    +      }
    +    }
    +  }
    +
    +  test("WLS against glmnet with constant features") {
    +    /*
    +       R code:
    +
    +       A <- matrix(c(1, 1, 1, 1, 5, 7, 11, 13), 4, 2)
    +       b <- c(17, 19, 23, 29)
    +       w <- c(1, 2, 3, 4)
    +     */
    +    val constantFeatures = sc.parallelize(Seq(
    --- End diff --
    
    It would be better to make this dataset a member variable of the class, like 
`instancesConstLabel`, so that other test cases can reuse it later.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastruct...@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to