Github user yanboliang commented on a diff in the pull request: https://github.com/apache/spark/pull/15394#discussion_r82980129 --- Diff: mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala --- @@ -132,24 +197,234 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext var idx = 0 for (fitIntercept <- Seq(false, true)) { for (standardization <- Seq(false, true)) { - val wls = new WeightedLeastSquares( - fitIntercept, regParam = 0.0, standardizeFeatures = standardization, - standardizeLabel = standardization).fit(instancesConstLabel) - val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = 0.0, elasticNetParam = 0.0, + standardizeFeatures = standardization, + standardizeLabel = standardization, solverType = solver).fit(instancesConstLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual ~== expected(idx) absTol 1e-4) + } } idx += 1 } + + // when label is constant zero, and fitIntercept is false, we should not train and get all zeros + val instancesConstZeroLabel = instancesConstLabel.map { case Instance(l, w, f) => + Instance(0.0, w, f) + } + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares(false, 0.0, 0.0, true, true, solverType = solver) + .fit(instancesConstZeroLabel) + val actual = Vectors.dense(wls.intercept, wls.coefficients(0), wls.coefficients(1)) + assert(actual === Vectors.dense(0.0, 0.0, 0.0)) + assert(wls.objectiveHistory === Array(0.0)) + } } test("WLS with regularization when label is constant") { // if regParam is non-zero and standardization is true, the problem is ill-defined and // an exception is thrown. 
- val wls = new WeightedLeastSquares( - fitIntercept = false, regParam = 0.1, standardizeFeatures = true, - standardizeLabel = true) - intercept[IllegalArgumentException]{ - wls.fit(instancesConstLabel) + for (solver <- WeightedLeastSquares.supportedSolvers) { + val wls = new WeightedLeastSquares( + fitIntercept = false, regParam = 0.1, elasticNetParam = 0.0, standardizeFeatures = true, + standardizeLabel = true, solverType = solver) + intercept[IllegalArgumentException]{ + wls.fit(instancesConstLabel) + } + } + } + + test("WLS against glmnet with constant features") { + /* + R code: + + A <- matrix(c(1, 1, 1, 1, 5, 7, 11, 13), 4, 2) + b <- c(17, 19, 23, 29) + w <- c(1, 2, 3, 4) + */ + val constantFeatures = sc.parallelize(Seq( + Instance(17.0, 1.0, Vectors.dense(1.0, 5.0)), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(1.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(1.0, 13.0)) + ), 2) + + // Cholesky solver does not handle singular input with no regularization + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true)) { + val wls = new WeightedLeastSquares(fitIntercept, 0.0, 0.0, standardization, standardization, + solverType = WeightedLeastSquares.Cholesky) + // for the case of no intercept, this would not have failed before but since we train + // in the standardized space now, it will fail + intercept[SingularMatrixException] { + wls.fit(constantFeatures) + } + } + + // should not fail when regularization is added + new WeightedLeastSquares(true, 0.5, 0.0, standardizeFeatures = true, + standardizeLabel = true, solverType = WeightedLeastSquares.Cholesky).fit(constantFeatures) + + /* + for (intercept in c(FALSE, TRUE)) { + for (standardize in c(FALSE, TRUE)) { + for (regParams in list(c(0.0, 0.0), c(0.5, 0.0), c(0.5, 0.5), c(0.5, 1.0))) { + model <- glmnet(A, b, weights=w, intercept=intercept, lambda=regParams[1], + standardize=standardize, alpha=regParams[2], thresh=1E-14) + 
print(as.vector(coef(model))) + } + } + } + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.250857 + [1] 0.000000 0.000000 2.249784 + [1] 0.000000 0.000000 2.248709 + [1] 0.000000 0.000000 2.253012 + [1] 0.000000 0.000000 2.235802 + [1] 0.000000 0.000000 2.238297 + [1] 0.000000 0.000000 2.240811 + [1] 8.218905 0.000000 1.517413 + [1] 8.434286 0.000000 1.496703 + [1] 8.648497 0.000000 1.476106 + [1] 8.865672 0.000000 1.455224 + [1] 8.218905 0.000000 1.517413 + [1] 9.798771 0.000000 1.365503 + [1] 9.919095 0.000000 1.353933 + [1] 10.052804 0.000000 1.341077 + */ + val expectedQuasiNewton = Seq( + Vectors.dense(0.000000, 0.000000, 2.253012), + Vectors.dense(0.000000, 0.000000, 2.250857), + Vectors.dense(0.000000, 0.000000, 2.249784), + Vectors.dense(0.000000, 0.000000, 2.248709), + Vectors.dense(0.000000, 0.000000, 2.253012), + Vectors.dense(0.000000, 0.000000, 2.235802), + Vectors.dense(0.000000, 0.000000, 2.238297), + Vectors.dense(0.000000, 0.000000, 2.240811), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(8.434286, 0.000000, 1.496703), + Vectors.dense(8.648497, 0.000000, 1.476106), + Vectors.dense(8.865672, 0.000000, 1.455224), + Vectors.dense(8.218905, 0.000000, 1.517413), + Vectors.dense(9.798771, 0.000000, 1.365503), + Vectors.dense(9.919095, 0.000000, 1.353933), + Vectors.dense(10.052804, 0.000000, 1.341077)) + var idx = 0 + for (fitIntercept <- Seq(false, true); + standardization <- Seq(false, true); + (lambda, alpha) <- Seq((0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0))) { + val wls = new WeightedLeastSquares(fitIntercept, regParam = lambda, elasticNetParam = alpha, + standardizeFeatures = standardization, standardizeLabel = true, + solverType = WeightedLeastSquares.QuasiNewton) + val model = wls.fit(constantFeatures) + val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) + assert(actual ~== expectedQuasiNewton(idx) absTol 1e-6) + idx += 1 + } + } + + test("WLS against glmnet with L1 
regularization") { + /* --- End diff -- Please add `library(glmnet)` to the R code in this comment block so that the snippet can be run as-is.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and you would like it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org