Repository: spark Updated Branches: refs/heads/master 82c1c5772 -> 5f1cee6f1
[SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquares.Instance WeightedLeastSquares now uses the common Instance class in ml.feature instead of a private one. Author: Nakul Jindal <njin...@us.ibm.com> Closes #9325 from nakul02/SPARK-11332_refactor_WeightedLeastSquares_dot_Instance. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f1cee6f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f1cee6f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f1cee6f Branch: refs/heads/master Commit: 5f1cee6f158adb1f9f485ed1d529c56bace68adc Parents: 82c1c57 Author: Nakul Jindal <njin...@us.ibm.com> Authored: Wed Oct 28 01:02:03 2015 -0700 Committer: DB Tsai <d...@netflix.com> Committed: Wed Oct 28 01:02:03 2015 -0700 ---------------------------------------------------------------------- .../spark/ml/optim/WeightedLeastSquares.scala | 25 +++++++------------- .../spark/ml/regression/LinearRegression.scala | 4 ++-- .../ml/optim/WeightedLeastSquaresSuite.scala | 10 ++++---- 3 files changed, 15 insertions(+), 24 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index d7eaa5a..3d64f7f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.Logging +import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.linalg._ import org.apache.spark.rdd.RDD @@ -122,16 +123,6 @@ private[ml] 
class WeightedLeastSquares( private[ml] object WeightedLeastSquares { /** - * Case class for weighted observations. - * @param w weight, must be positive - * @param a features - * @param b label - */ - case class Instance(w: Double, a: Vector, b: Double) { - require(w >= 0.0, s"Weight cannot be negative: $w.") - } - - /** * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]]. */ // TODO: consolidate aggregates for summary statistics @@ -168,8 +159,8 @@ private[ml] object WeightedLeastSquares { * Adds an instance. */ def add(instance: Instance): this.type = { - val Instance(w, a, b) = instance - val ak = a.size + val Instance(l, w, f) = instance + val ak = f.size if (!initialized) { init(ak) } @@ -177,11 +168,11 @@ private[ml] object WeightedLeastSquares { count += 1L wSum += w wwSum += w * w - bSum += w * b - bbSum += w * b * b - BLAS.axpy(w, a, aSum) - BLAS.axpy(w * b, a, abSum) - BLAS.spr(w, a, aaSum) + bSum += w * l + bbSum += w * l * l + BLAS.axpy(w, f, aSum) + BLAS.axpy(w * l, f, abSum) + BLAS.spr(w, f, aaSum) this } http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index c3ee8b3..f663b9b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -154,10 +154,10 @@ class LinearRegression(override val uid: String) "solver is used.'") // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. 
(SPARK-10668) - val instances: RDD[WeightedLeastSquares.Instance] = dataset.select( + val instances: RDD[Instance] = dataset.select( col($(labelCol)), w, col($(featuresCol))).map { case Row(label: Double, weight: Double, features: Vector) => - WeightedLeastSquares.Instance(weight, features, label) + Instance(label, weight, features) } val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam), http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala index 652f3ad..b542ba3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.optim import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.optim.WeightedLeastSquares.Instance +import org.apache.spark.ml.feature.Instance import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -38,10 +38,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext w <- c(1, 2, 3, 4) */ instances = sc.parallelize(Seq( - Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0), - Instance(2.0, Vectors.dense(1.0, 7.0), 19.0), - Instance(3.0, Vectors.dense(2.0, 11.0), 23.0), - Instance(4.0, Vectors.dense(3.0, 13.0), 29.0) + Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2) } --------------------------------------------------------------------- To unsubscribe, 
e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org