Repository: spark
Updated Branches:
  refs/heads/master ca9fe540f -> 331f0b10f


[SPARK-9642] [ML] LinearRegression should support weighted data

In many modeling applications, data points are not necessarily sampled with
equal probability. Linear regression should support instance weights that
account for such over- or under-sampling.

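A minimal usage sketch of the new parameter (the DataFrame `df` and its
column names are illustrative, not part of this patch):

    import org.apache.spark.ml.regression.LinearRegression

    // df is assumed to contain "label", "features", and a per-instance
    // "weight" column; weightCol defaults to "", meaning every instance
    // is treated as having weight 1.0.
    val lr = new LinearRegression()
      .setWeightCol("weight")
      .setMaxIter(100)
    val model = lr.fit(df)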

Author: Meihua Wu <meihu...@umich.edu>

Closes #8631 from rotationsymmetry/SPARK-9642.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/331f0b10
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/331f0b10
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/331f0b10

Branch: refs/heads/master
Commit: 331f0b10f78a37d96d3e573d211d74a0935265db
Parents: ca9fe54
Author: Meihua Wu <meihu...@umich.edu>
Authored: Mon Sep 21 12:09:00 2015 -0700
Committer: DB Tsai <d...@netflix.com>
Committed: Mon Sep 21 12:09:00 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/regression/LinearRegression.scala  | 164 +++++++++++--------
 .../ml/regression/LinearRegressionSuite.scala   |  88 ++++++++++
 project/MimaExcludes.scala                      |   8 +-
 3 files changed, 191 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/331f0b10/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index e4602d3..78a67c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -31,21 +31,29 @@ import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.functions.{col, udf, lit}
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.StatCounter
 
 /**
  * Params for linear regression.
  */
 private[regression] trait LinearRegressionParams extends PredictorParams
     with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
-    with HasFitIntercept with HasStandardization
+    with HasFitIntercept with HasStandardization with HasWeightCol
+
+/**
+ * Class that represents an instance of a weighted data point with label and features.
+ *
+ * TODO: Refactor this class to proper place.
+ *
+ * @param label Label for this data point.
+ * @param weight The weight of this instance.
+ * @param features The vector of features for this data point.
+ */
+private[regression] case class Instance(label: Double, weight: Double, features: Vector)
 
 /**
  * :: Experimental ::
@@ -123,30 +131,43 @@ class LinearRegression(override val uid: String)
   def setTol(value: Double): this.type = set(tol, value)
   setDefault(tol -> 1E-6)
 
+  /**
+   * Whether to over-/under-sample training instances according to the given weights in weightCol.
+   * If empty, all instances are treated equally (weight 1.0).
+   * Default is empty, so all instances have weight one.
+   * @group setParam
+   */
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+  setDefault(weightCol -> "")
+
   override protected def train(dataset: DataFrame): LinearRegressionModel = {
     // Extract columns from data.  If dataset is persisted, do not persist instances.
-    val instances = extractLabeledPoints(dataset).map {
-      case LabeledPoint(label: Double, features: Vector) => (label, features)
+    val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+      case Row(label: Double, weight: Double, features: Vector) =>
+        Instance(label, weight, features)
     }
+
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
 
-    val (summarizer, statCounter) = instances.treeAggregate(
-      (new MultivariateOnlineSummarizer, new StatCounter))(
-        seqOp = (c, v) => (c, v) match {
-          case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
-          (label: Double, features: Vector)) =>
-            (summarizer.add(features), statCounter.merge(label))
-      },
-        combOp = (c1, c2) => (c1, c2) match {
-          case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
-          (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
-            (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
-      })
-
-    val numFeatures = summarizer.mean.size
-    val yMean = statCounter.mean
-    val yStd = math.sqrt(statCounter.variance)
+    val (featuresSummarizer, ySummarizer) = {
+      val seqOp = (c: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
+        instance: Instance) =>
+          (c._1.add(instance.features, instance.weight),
+            c._2.add(Vectors.dense(instance.label), instance.weight))
+
+      val combOp = (c1: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer),
+        c2: (MultivariateOnlineSummarizer, MultivariateOnlineSummarizer)) =>
+          (c1._1.merge(c2._1), c1._2.merge(c2._2))
+
+      instances.treeAggregate(
+        new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp)
+    }
+
+    val numFeatures = featuresSummarizer.mean.size
+    val yMean = ySummarizer.mean(0)
+    val yStd = math.sqrt(ySummarizer.variance(0))
 
     // If the yStd is zero, then the intercept is yMean with zero weights;
     // as a result, training is not needed.
@@ -167,8 +188,8 @@ class LinearRegression(override val uid: String)
       return copyValues(model.setSummary(trainingSummary))
     }
 
-    val featuresMean = summarizer.mean.toArray
-    val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+    val featuresMean = featuresSummarizer.mean.toArray
+    val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
 
     // Since we implicitly do the feature scaling when we compute the cost function
     // to improve the convergence, the effective regParam will be changed.
@@ -318,7 +339,8 @@ class LinearRegressionModel private[ml] (
 
 /**
  * :: Experimental ::
- * Linear regression training results.
+ * Linear regression training results. Currently, the training summary ignores the
+ * training weights except for the objective trace.
  * @param predictions predictions outputted by the model's `transform` method.
 * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
  */
@@ -477,7 +499,7 @@ class LinearRegressionSummary private[regression] (
  * \frac{\partial L}{\partial\w_i} = 1/N ((\sum_j diff_j x_{ij} / \hat{x_i})
  * }}},
  *
- * @param weights The weights/coefficients corresponding to the features.
+ * @param coefficients The coefficients corresponding to the features.
  * @param labelStd The standard deviation value of the label.
  * @param labelMean The mean value of the label.
  * @param fitIntercept Whether to fit an intercept term.
@@ -485,7 +507,7 @@ class LinearRegressionSummary private[regression] (
  * @param featuresMean The mean values of the features.
  */
 private class LeastSquaresAggregator(
-    weights: Vector,
+    coefficients: Vector,
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
@@ -493,26 +515,28 @@ private class LeastSquaresAggregator(
     featuresMean: Array[Double]) extends Serializable {
 
   private var totalCnt: Long = 0L
+  private var weightSum: Double = 0.0
   private var lossSum = 0.0
 
-  private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = {
-    val weightsArray = weights.toArray.clone()
+  private val (effectiveCoefficientsArray: Array[Double], offset: Double, dim: Int) = {
+    val coefficientsArray = coefficients.toArray.clone()
     var sum = 0.0
     var i = 0
-    val len = weightsArray.length
+    val len = coefficientsArray.length
     while (i < len) {
       if (featuresStd(i) != 0.0) {
-        weightsArray(i) /=  featuresStd(i)
-        sum += weightsArray(i) * featuresMean(i)
+        coefficientsArray(i) /=  featuresStd(i)
+        sum += coefficientsArray(i) * featuresMean(i)
       } else {
-        weightsArray(i) = 0.0
+        coefficientsArray(i) = 0.0
       }
       i += 1
     }
-    (weightsArray, if (fitIntercept) labelMean / labelStd - sum else 0.0, weightsArray.length)
+    val offset = if (fitIntercept) labelMean / labelStd - sum else 0.0
+    (coefficientsArray, offset, coefficientsArray.length)
   }
 
-  private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
+  private val effectiveCoefficientsVector = Vectors.dense(effectiveCoefficientsArray)
 
   private val gradientSumArray = Array.ofDim[Double](dim)
 
@@ -520,30 +544,33 @@ private class LeastSquaresAggregator(
    * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
    * of the objective function.
    *
-   * @param label The label for this data point.
-   * @param data The features for one data point in dense/sparse vector format to be added
-   *             into this aggregator.
+   * @param instance  The data point instance to be added.
    * @return This LeastSquaresAggregator object.
    */
-  def add(label: Double, data: Vector): this.type = {
-    require(dim == data.size, s"Dimensions mismatch when adding new sample." +
-      s" Expecting $dim but got ${data.size}.")
+  def add(instance: Instance): this.type =
+    instance match { case Instance(label, weight, features) =>
+      require(dim == features.size, s"Dimensions mismatch when adding new sample." +
+        s" Expecting $dim but got ${features.size}.")
+      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
 
-    val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset
+      if (weight == 0.0) return this
 
-    if (diff != 0) {
-      val localGradientSumArray = gradientSumArray
-      data.foreachActive { (index, value) =>
-        if (featuresStd(index) != 0.0 && value != 0.0) {
-          localGradientSumArray(index) += diff * value / featuresStd(index)
+      val diff = dot(features, effectiveCoefficientsVector) - label / labelStd + offset
+
+      if (diff != 0) {
+        val localGradientSumArray = gradientSumArray
+        features.foreachActive { (index, value) =>
+          if (featuresStd(index) != 0.0 && value != 0.0) {
+            localGradientSumArray(index) += weight * diff * value / featuresStd(index)
+          }
         }
+        lossSum += weight * diff * diff / 2.0
       }
-      lossSum += diff * diff / 2.0
-    }
 
-    totalCnt += 1
-    this
-  }
+      totalCnt += 1
+      weightSum += weight
+      this
+    }
 
   /**
    * Merge another LeastSquaresAggregator, and update the loss and gradient
@@ -557,8 +584,9 @@ private class LeastSquaresAggregator(
     require(dim == other.dim, s"Dimensions mismatch when merging with another " +
       s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
 
-    if (other.totalCnt != 0) {
+    if (other.weightSum != 0) {
       totalCnt += other.totalCnt
+      weightSum += other.weightSum
       lossSum += other.lossSum
 
       var i = 0
@@ -574,11 +602,17 @@ private class LeastSquaresAggregator(
 
   def count: Long = totalCnt
 
-  def loss: Double = lossSum / totalCnt
+  def loss: Double = {
+    require(weightSum > 0.0, s"The effective number of instances should be " +
+      s"greater than 0.0, but $weightSum.")
+    lossSum / weightSum
+  }
 
   def gradient: Vector = {
+    require(weightSum > 0.0, s"The effective number of instances should be " +
+      s"greater than 0.0, but $weightSum.")
     val result = Vectors.dense(gradientSumArray.clone())
-    scal(1.0 / totalCnt, result)
+    scal(1.0 / weightSum, result)
     result
   }
 }
@@ -589,7 +623,7 @@ private class LeastSquaresAggregator(
  * It's used in Breeze's convex optimization routines.
  */
 private class LeastSquaresCostFun(
-    data: RDD[(Double, Vector)],
+    data: RDD[Instance],
     labelStd: Double,
     labelMean: Double,
     fitIntercept: Boolean,
@@ -598,17 +632,13 @@ private class LeastSquaresCostFun(
     featuresMean: Array[Double],
     effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
 
-  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
-    val w = Vectors.fromBreeze(weights)
+  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
+    val coeff = Vectors.fromBreeze(coefficients)
 
-    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+    val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd,
       labelMean, fitIntercept, featuresStd, featuresMean))(
-        seqOp = (c, v) => (c, v) match {
-          case (aggregator, (label, features)) => aggregator.add(label, features)
-        },
-        combOp = (c1, c2) => (c1, c2) match {
-          case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
-        })
+        seqOp = (aggregator, instance) => aggregator.add(instance),
+        combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
 
     val totalGradientArray = leastSquaresAggregator.gradient.toArray
 
@@ -616,7 +646,7 @@ private class LeastSquaresCostFun(
       0.0
     } else {
       var sum = 0.0
-      w.foreachActive { (index, value) =>
+      coeff.foreachActive { (index, value) =>
         // The following code will compute the loss of the regularization; also
         // the gradient of the regularization, and add back to totalGradientArray.
         sum += {
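
In effect, the aggregator above swaps the unweighted average of per-instance
losses for a weighted one. A sketch of the accumulated objective (before
regularization), matching `lossSum += weight * diff * diff / 2.0` and the
final division by `weightSum`:

    L(\beta) = \frac{1}{\sum_i w_i} \sum_i \frac{w_i}{2} \mathrm{diff}_i^2

where diff_i is the residual of instance i in the standardized feature/label
space. The gradient is scaled by 1 / \sum_i w_i in the same way, so
zero-weight instances contribute to neither the loss nor the gradient.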

http://git-wip-us.apache.org/repos/asf/spark/blob/331f0b10/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 2aaee71..8428f4f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -17,10 +17,13 @@
 
 package org.apache.spark.ml.regression
 
+import scala.util.Random
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -510,4 +513,89 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
       .zip(testSummary.residuals.select("residuals").collect())
      .forall { case (Row(r1: Double), Row(r2: Double)) => r1 ~== r2 relTol 1E-5 }
   }
+
+  test("linear regression with weighted samples"){
+    val (data, weightedData) = {
+      val activeData = LinearDataGenerator.generateLinearInput(
+        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+
+      val rnd = new Random(8392)
+      val signedData = activeData.map { case p: LabeledPoint =>
+        (rnd.nextGaussian() > 0.0, p)
+      }
+
+      val data1 = signedData.flatMap {
+        case (true, p) => Iterator(p, p)
+        case (false, p) => Iterator(p)
+      }
+
+      val weightedSignedData = signedData.flatMap {
+        case (true, LabeledPoint(label, features)) =>
+          Iterator(
+            Instance(label, weight = 1.2, features),
+            Instance(label, weight = 0.8, features)
+          )
+        case (false, LabeledPoint(label, features)) =>
+          Iterator(
+            Instance(label, weight = 0.3, features),
+            Instance(label, weight = 0.1, features),
+            Instance(label, weight = 0.6, features)
+          )
+      }
+
+      val noiseData = LinearDataGenerator.generateLinearInput(
+        2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1)
+      val weightedNoiseData = noiseData.map {
+        case LabeledPoint(label, features) => Instance(label, weight = 0, features)
+      }
+      val data2 = weightedSignedData ++ weightedNoiseData
+
+      (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
+        sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+    }
+
+    val trainer1a = (new LinearRegression).setFitIntercept(true)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val model1a0 = trainer1a.fit(data)
+    val model1a1 = trainer1a.fit(weightedData)
+    val model1b = trainer1b.fit(weightedData)
+    assert(model1a0.weights !~= model1a1.weights absTol 1E-3)
+    assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3)
+    assert(model1a0.weights ~== model1b.weights absTol 1E-3)
+    assert(model1a0.intercept ~== model1b.intercept absTol 1E-3)
+
+    val trainer2a = (new LinearRegression).setFitIntercept(true)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+    val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight")
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+    val model2a0 = trainer2a.fit(data)
+    val model2a1 = trainer2a.fit(weightedData)
+    val model2b = trainer2b.fit(weightedData)
+    assert(model2a0.weights !~= model2a1.weights absTol 1E-3)
+    assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3)
+    assert(model2a0.weights ~== model2b.weights absTol 1E-3)
+    assert(model2a0.intercept ~== model2b.intercept absTol 1E-3)
+
+    val trainer3a = (new LinearRegression).setFitIntercept(false)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true)
+    val model3a0 = trainer3a.fit(data)
+    val model3a1 = trainer3a.fit(weightedData)
+    val model3b = trainer3b.fit(weightedData)
+    assert(model3a0.weights !~= model3a1.weights absTol 1E-3)
+    assert(model3a0.weights ~== model3b.weights absTol 1E-3)
+
+    val trainer4a = (new LinearRegression).setFitIntercept(false)
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+    val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight")
+      .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(false)
+    val model4a0 = trainer4a.fit(data)
+    val model4a1 = trainer4a.fit(weightedData)
+    val model4b = trainer4b.fit(weightedData)
+    assert(model4a0.weights !~= model4a1.weights absTol 1E-3)
+    assert(model4a0.weights ~== model4b.weights absTol 1E-3)
+  }
 }
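
The suite relies on the equivalence between weighting and replication: each
"true"-signed point appears twice in the unweighted data and carries weights
summing to 2.0 (1.2 + 0.8) in the weighted data; each "false"-signed point
appears once and carries weights summing to 1.0 (0.3 + 0.1 + 0.6); and the
zero-weight noise block should leave the weighted fit untouched. Fitting
without the weight column (the `a1` models) must therefore move the
coefficients, while honoring the weights (the `b` models) should recover the
fit on the replicated data.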

http://git-wip-us.apache.org/repos/asf/spark/blob/331f0b10/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 814a11e..b2e6be7 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -70,10 +70,14 @@ object MimaExcludes {
           "org.apache.spark.scheduler.AskPermissionToCommitOutput.this"),
         ProblemFilters.exclude[IncompatibleMethTypeProblem](
           "org.apache.spark.scheduler.AskPermissionToCommitOutput.apply")
-      ) ++
-      Seq(
+      ) ++ Seq(
         ProblemFilters.exclude[MissingClassProblem](
           "org.apache.spark.shuffle.FileShuffleBlockResolver$ShuffleFileGroup")
+      ) ++ Seq(
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.ml.regression.LeastSquaresAggregator.add"),
+        ProblemFilters.exclude[MissingMethodProblem](
+          "org.apache.spark.ml.regression.LeastSquaresCostFun.this")
       )
     case v if v.startsWith("1.5") =>
       Seq(

