Repository: spark
Updated Branches:
  refs/heads/master e8e0fd691 -> f6a189930


Streaming mllib [SPARK-2438][MLLIB]

This PR implements a streaming linear regression analysis, in which a linear 
regression model is trained online as new data arrive. The design is based on 
discussions with tdas and mengxr, in which we determined how to add this 
functionality in a general way, with minimal changes to existing libraries.

__Summary of additions:__

_StreamingLinearAlgorithm_
- An abstract class for fitting generalized linear models online to streaming 
data, including training on (and updating) a model, and making predictions.

_StreamingLinearRegressionWithSGD_
- Class and companion object for running streaming linear regression

_StreamingLinearRegressionTestSuite_
- Unit tests

_StreamingLinearRegression_
- Example use case: fitting a model online to data from one stream, and making 
predictions on other data

__Notes__
- If this looks good, I can use the StreamingLinearAlgorithm class to easily 
implement other analyses that follow the same logic (Ridge, Lasso, Logistic, 
SVM).
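As a sketch of the shared logic those analyses would reuse, the following plain-Scala snippet (hypothetical, with no Spark dependency; names are illustrative only) shows the core per-batch update this PR implements: each new batch of labeled points resumes mini-batch SGD from the previous weights, so the model is trained online as batches arrive.

```scala
// Hypothetical plain-Scala sketch of the per-batch SGD update shared by
// streaming GLM variants. Not MLlib code; names are illustrative.
object StreamingSgdSketch {
  // One averaged gradient step on squared loss: w -= step * mean((w.x - y) * x)
  def sgdStep(w: Array[Double], batch: Seq[(Array[Double], Double)], step: Double): Array[Double] = {
    val grad = Array.fill(w.length)(0.0)
    for ((x, y) <- batch) {
      val err = x.zip(w).map { case (xi, wi) => xi * wi }.sum - y
      for (i <- w.indices) grad(i) += err * x(i)
    }
    w.indices.map(i => w(i) - step * grad(i) / batch.size).toArray
  }

  // Each batch resumes from the previous weights, as trainOn does per RDD
  def trainOnBatches(init: Array[Double],
                     batches: Seq[Seq[(Array[Double], Double)]],
                     step: Double, iters: Int): Array[Double] =
    batches.foldLeft(init) { (w, batch) =>
      (1 to iters).foldLeft(w)((cur, _) => sgdStep(cur, batch, step))
    }
}
```

Swapping the loss/gradient in `sgdStep` is, in spirit, all that distinguishes the Ridge/Lasso/Logistic/SVM variants mentioned above.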

Author: Jeremy Freeman <the.freeman....@gmail.com>
Author: freeman <the.freeman....@gmail.com>

Closes #1361 from freeman-lab/streaming-mllib and squashes the following 
commits:

775ea29 [Jeremy Freeman] Throw error if user doesn't initialize weights
4086fee [Jeremy Freeman] Fixed current weight formatting
8b95b27 [Jeremy Freeman] Restored broadcasting
29f27ec [Jeremy Freeman] Formatting
8711c41 [Jeremy Freeman] Used return to avoid indentation
777b596 [Jeremy Freeman] Restored treeAggregate
74cf440 [Jeremy Freeman] Removed static methods
d28cf9a [Jeremy Freeman] Added usage notes
c3326e7 [Jeremy Freeman] Improved documentation
9541a41 [Jeremy Freeman] Merge remote-tracking branch 'upstream/master' into streaming-mllib
66eba5e [Jeremy Freeman] Fixed line lengths
2fe0720 [Jeremy Freeman] Minor cleanup
7d51378 [Jeremy Freeman] Moved streaming loader to MLUtils
b9b69f6 [Jeremy Freeman] Added setter methods
c3f8b5a [Jeremy Freeman] Modified logging
00aafdc [Jeremy Freeman] Add modifiers
14b801e [Jeremy Freeman] Name changes
c7d38a3 [Jeremy Freeman] Move check for empty data to GradientDescent
4b0a5d3 [Jeremy Freeman] Cleaned up tests
74188d6 [Jeremy Freeman] Eliminate dependency on commons
50dd237 [Jeremy Freeman] Removed experimental tag
6bfe1e6 [Jeremy Freeman] Fixed imports
a2a63ad [freeman] Makes convergence test more robust
86220bc [freeman] Streaming linear regression unit tests
fb4683a [freeman] Minor changes for scalastyle consistency
fd31e03 [freeman] Changed logging behavior
453974e [freeman] Fixed indentation
c4b1143 [freeman] Streaming linear regression
604f4d7 [freeman] Expanded private class to include mllib
d99aa85 [freeman] Helper methods for streaming MLlib apps
0898add [freeman] Added dependency on streaming


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6a18993
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6a18993
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6a18993

Branch: refs/heads/master
Commit: f6a1899306c5ad766fea122d3ab4b83436d9f6fd
Parents: e8e0fd6
Author: Jeremy Freeman <the.freeman....@gmail.com>
Authored: Fri Aug 1 20:10:26 2014 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Fri Aug 1 20:10:26 2014 -0700

----------------------------------------------------------------------
 .../mllib/StreamingLinearRegression.scala       |  73 ++++++++++
 mllib/pom.xml                                   |   5 +
 .../mllib/optimization/GradientDescent.scala    |   9 ++
 .../mllib/regression/LinearRegression.scala     |   4 +-
 .../regression/StreamingLinearAlgorithm.scala   | 106 +++++++++++++++
 .../StreamingLinearRegressionWithSGD.scala      |  88 ++++++++++++
 .../org/apache/spark/mllib/util/MLUtils.scala   |  15 +++
 .../StreamingLinearRegressionSuite.scala        | 135 +++++++++++++++++++
 8 files changed, 433 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
new file mode 100644
index 0000000..1fd37ed
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Train a linear regression model on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the text files must be labeled data points in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * where n is the number of features. n must be the same for train and test.
+ *
+ * Usage: StreamingLinearRegression <trainingDir> <testDir> <batchDuration> <numFeatures>
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, and 2 features per data point, call:
+ *    $ bin/run-example \
+ *        org.apache.spark.examples.mllib.StreamingLinearRegression trainingDir testDir 5 2
+ *
+ * As you add text files to `trainingDir` the model will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predictions from the current model.
+ *
+ */
+object StreamingLinearRegression {
+
+  def main(args: Array[String]) {
+
+    if (args.length != 4) {
+      System.err.println(
+        "Usage: StreamingLinearRegression <trainingDir> <testDir> <batchDuration> <numFeatures>")
+      System.exit(1)
+    }
+
+    val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression")
+    val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+    val trainingData = MLUtils.loadStreamingLabeledPoints(ssc, args(0))
+    val testData = MLUtils.loadStreamingLabeledPoints(ssc, args(1))
+
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(Array.fill[Double](args(3).toInt)(0)))
+
+    model.trainOn(trainingData)
+    model.predictOn(testData).print()
+
+    ssc.start()
+    ssc.awaitTermination()
+
+  }
+
+}
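The `(y,[x1,x2,...,xn])` line format the example above reads can be illustrated with a standalone parser. This is a hypothetical sketch for illustration only; in the actual code path, `MLUtils.loadStreamingLabeledPoints` delegates parsing to MLlib's `LabeledPointParser`.

```scala
// Hypothetical parser for lines like "(1.0,[0.5,2.0])"; illustrative only,
// not the MLlib implementation.
object LabeledLineSketch {
  def parse(line: String): (Double, Array[Double]) = {
    // Strip the outer parentheses, split label from the feature vector
    val inner = line.trim.stripPrefix("(").stripSuffix(")")
    val comma = inner.indexOf(',')
    val label = inner.substring(0, comma).toDouble
    val feats = inner.substring(comma + 1).stripPrefix("[").stripSuffix("]")
      .split(',').map(_.trim.toDouble)
    (label, feats)
  }
}
```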

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 45046ec..9a33bd1 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -41,6 +41,11 @@
       <version>${project.version}</version>
     </dependency>
     <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-server</artifactId>
     </dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
index 356aa94..a691205 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala
@@ -162,6 +162,14 @@ object GradientDescent extends Logging {
     val numExamples = data.count()
     val miniBatchSize = numExamples * miniBatchFraction
 
+    // if no data, return initial weights to avoid NaNs
+    if (numExamples == 0) {
+
+      logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
+      return (initialWeights, stochasticLossHistory.toArray)
+
+    }
+
     // Initialize weights as a column vector
     var weights = Vectors.dense(initialWeights.toArray)
     val n = weights.size
@@ -202,5 +210,6 @@ object GradientDescent extends Logging {
       stochasticLossHistory.takeRight(10).mkString(", ")))
 
     (weights, stochasticLossHistory.toArray)
+
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 8c078ec..81b6598 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -49,7 +49,7 @@ class LinearRegressionModel (
  * its corresponding right hand side label y.
  * See also the documentation for the precise formulation.
  */
-class LinearRegressionWithSGD private (
+class LinearRegressionWithSGD private[mllib] (
     private var stepSize: Double,
     private var numIterations: Int,
     private var miniBatchFraction: Double)
@@ -68,7 +68,7 @@ class LinearRegressionWithSGD private (
    */
   def this() = this(1.0, 100, 1.0)
 
-  override protected def createModel(weights: Vector, intercept: Double) = {
+  override protected[mllib] def createModel(weights: Vector, intercept: Double) = {
     new LinearRegressionModel(weights, intercept)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
new file mode 100644
index 0000000..b8b0b42
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.Logging
+import org.apache.spark.streaming.dstream.DStream
+
+/**
+ * :: DeveloperApi ::
+ * StreamingLinearAlgorithm implements methods for continuously
+ * training a generalized linear model on streaming data,
+ * and using it for prediction on (possibly different) streaming data.
+ *
+ * This class takes as type parameters a GeneralizedLinearModel,
+ * and a GeneralizedLinearAlgorithm, making it easy to extend to construct
+ * streaming versions of any analyses using GLMs.
+ * Initial weights must be set before calling trainOn or predictOn.
+ * Only weights will be updated, not an intercept. If the model needs
+ * an intercept, it should be manually appended to the input data.
+ *
+ * For example usage, see `StreamingLinearRegressionWithSGD`.
+ *
+ * NOTE(Freeman): In some use cases, the order in which trainOn and predictOn
+ * are called in an application will affect the results. When called on
+ * the same DStream, if trainOn is called before predictOn, when new data
+ * arrive the model will update and the prediction will be based on the new
+ * model. Whereas if predictOn is called first, the prediction will use the model
+ * from the previous update.
+ *
+ * NOTE(Freeman): It is ok to call predictOn repeatedly on multiple streams; this
+ * will generate predictions for each one all using the current model.
+ * It is also ok to call trainOn on different streams; this will update
+ * the model using each of the different sources, in sequence.
+ *
+ */
+@DeveloperApi
+abstract class StreamingLinearAlgorithm[
+    M <: GeneralizedLinearModel,
+    A <: GeneralizedLinearAlgorithm[M]] extends Logging {
+
+  /** The model to be updated and used for prediction. */
+  protected var model: M
+
+  /** The algorithm to use for updating. */
+  protected val algorithm: A
+
+  /** Return the latest model. */
+  def latestModel(): M = {
+    model
+  }
+
+  /**
+   * Update the model by training on batches of data from a DStream.
+   * This operation registers a DStream for training the model,
+   * and updates the model based on every subsequent
+   * batch of data from the stream.
+   *
+   * @param data DStream containing labeled data
+   */
+  def trainOn(data: DStream[LabeledPoint]) {
+    if (Option(model.weights) == None) {
+      logError("Initial weights must be set before starting training")
+      throw new IllegalArgumentException
+    }
+    data.foreachRDD { (rdd, time) =>
+        model = algorithm.run(rdd, model.weights)
+        logInfo("Model updated at time %s".format(time.toString))
+        val display = model.weights.size match {
+          case x if x > 100 => model.weights.toArray.take(100).mkString("[", ",", "...")
+          case _ => model.weights.toArray.mkString("[", ",", "]")
+        }
+        logInfo("Current model: weights, %s".format (display))
+    }
+  }
+
+  /**
+   * Use the model to make predictions on batches of data from a DStream
+   *
+   * @param data DStream containing labeled data
+   * @return DStream containing predictions
+   */
+  def predictOn(data: DStream[LabeledPoint]): DStream[Double] = {
+    if (Option(model.weights) == None) {
+      logError("Initial weights must be set before starting prediction")
+      throw new IllegalArgumentException
+    }
+    data.map(x => model.predict(x.features))
+  }
+
+}
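The ordering note in the class doc above (trainOn before predictOn sees the updated model; the reverse sees the stale one) can be demonstrated without Spark. This is a deliberately trivial sketch, hypothetical and not MLlib code, where "training" just records the latest batch value:

```scala
// Toy illustration of the trainOn/predictOn ordering semantics described
// above; "model" is just the last batch value seen. Illustrative only.
object OrderingSketch {
  def run(batches: Seq[Double], trainFirst: Boolean): Seq[Double] = {
    var model = 0.0
    batches.map { b =>
      if (trainFirst) { model = b; model }     // train then predict: sees new batch
      else { val p = model; model = b; p }     // predict then train: sees stale model
    }
  }
}
```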

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
new file mode 100644
index 0000000..8851097
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+ * Train or predict a linear regression model on streaming data. Training uses
+ * Stochastic Gradient Descent to update the model based on each new batch of
+ * incoming data from a DStream (see `LinearRegressionWithSGD` for model equation)
+ *
+ * Each batch of data is assumed to be an RDD of LabeledPoints.
+ * The number of data points per batch can vary, but the number
+ * of features must be constant. An initial weight
+ * vector must be provided.
+ *
+ * Use a builder pattern to construct a streaming linear regression
+ * analysis in an application, like:
+ *
+ *  val model = new StreamingLinearRegressionWithSGD()
+ *    .setStepSize(0.5)
+ *    .setNumIterations(10)
+ *    .setInitialWeights(Vectors.dense(...))
+ *    .trainOn(DStream)
+ *
+ */
+@Experimental
+class StreamingLinearRegressionWithSGD (
+    private var stepSize: Double,
+    private var numIterations: Int,
+    private var miniBatchFraction: Double,
+    private var initialWeights: Vector)
+  extends StreamingLinearAlgorithm[
+    LinearRegressionModel, LinearRegressionWithSGD] with Serializable {
+
+  /**
+   * Construct a StreamingLinearRegression object with default parameters:
+   * {stepSize: 0.1, numIterations: 50, miniBatchFraction: 1.0}.
+   * Initial weights must be set before using trainOn or predictOn
+   * (see `StreamingLinearAlgorithm`)
+   */
+  def this() = this(0.1, 50, 1.0, null)
+
+  val algorithm = new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction)
+
+  var model = algorithm.createModel(initialWeights, 0.0)
+
+  /** Set the step size for gradient descent. Default: 0.1. */
+  def setStepSize(stepSize: Double): this.type = {
+    this.algorithm.optimizer.setStepSize(stepSize)
+    this
+  }
+
+  /** Set the number of iterations of gradient descent to run per update. Default: 50. */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.algorithm.optimizer.setNumIterations(numIterations)
+    this
+  }
+
+  /** Set the fraction of each batch to use for updates. Default: 1.0. */
+  def setMiniBatchFraction(miniBatchFraction: Double): this.type = {
+    this.algorithm.optimizer.setMiniBatchFraction(miniBatchFraction)
+    this
+  }
+
+  /** Set the initial weights. Default: [0.0, 0.0]. */
+  def setInitialWeights(initialWeights: Vector): this.type = {
+    this.model = algorithm.createModel(initialWeights, 0.0)
+    this
+  }
+
+}
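The setters above return `this.type` so configuration calls chain fluently, even through subclasses. A minimal standalone sketch of that builder pattern (illustrative names, not MLlib's):

```scala
// Minimal sketch of the `this.type` builder pattern used by the setters
// above: each setter mutates a field and returns the same instance.
class SketchConfig {
  private var stepSize: Double = 0.1
  private var numIterations: Int = 50

  def setStepSize(s: Double): this.type = { stepSize = s; this }
  def setNumIterations(n: Int): this.type = { numIterations = n; this }
  def summary: String = s"step=$stepSize, iters=$numIterations"
}
```

Returning `this.type` rather than `SketchConfig` is what lets a subclass keep chaining its own setters after calling an inherited one.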

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index dc10a19..f4cce86 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -30,6 +30,8 @@ import org.apache.spark.util.random.BernoulliSampler
 import org.apache.spark.mllib.regression.{LabeledPointParser, LabeledPoint}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.StreamingContext
+import org.apache.spark.streaming.dstream.DStream
 
 /**
  * Helper methods to load, save and pre-process data used in ML Lib.
@@ -193,6 +195,19 @@ object MLUtils {
     loadLabeledPoints(sc, dir, sc.defaultMinPartitions)
 
   /**
+   * Loads streaming labeled points from a stream of text files
+   * where points are in the same format as used in `RDD[LabeledPoint].saveAsTextFile`.
+   * See `StreamingContext.textFileStream` for more details on how to
+   * generate a stream from files
+   *
+   * @param ssc Streaming context
+   * @param dir Directory path in any Hadoop-supported file system URI
+   * @return Labeled points stored as a DStream[LabeledPoint]
+   */
+  def loadStreamingLabeledPoints(ssc: StreamingContext, dir: String): DStream[LabeledPoint] =
+    ssc.textFileStream(dir).map(LabeledPointParser.parse)
+
+  /**
    * Load labeled data from a file. The data format used here is
    * <L>, <f1> <f2> ...
   * where <f1>, <f2> are feature values in Double and <L> is the corresponding label as Double.

http://git-wip-us.apache.org/repos/asf/spark/blob/f6a18993/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
new file mode 100644
index 0000000..ed21f84
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import java.io.File
+import java.nio.charset.Charset
+
+import scala.collection.mutable.ArrayBuffer
+
+import com.google.common.io.Files
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext, MLUtils}
+import org.apache.spark.streaming.{Milliseconds, StreamingContext}
+import org.apache.spark.util.Utils
+
+class StreamingLinearRegressionSuite extends FunSuite with LocalSparkContext {
+
+  // Assert that two values are equal within tolerance epsilon
+  def assertEqual(v1: Double, v2: Double, epsilon: Double) {
+    def errorMessage = v1.toString + " did not equal " + v2.toString
+    assert(math.abs(v1-v2) <= epsilon, errorMessage)
+  }
+
+  // Assert that model predictions are correct
+  def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      // A prediction is off if the prediction is more than 0.5 away from expected value.
+      math.abs(prediction - expected.label) > 0.5
+    }
+    // At least 80% of the predictions should be on.
+    assert(numOffPredictions < input.length / 5)
+  }
+
+  // Test if we can accurately learn Y = 10*X1 + 10*X2 on streaming data
+  test("streaming linear regression parameter accuracy") {
+
+    val testDir = Files.createTempDir()
+    val numBatches = 10
+    val batchDuration = Milliseconds(1000)
+    val ssc = new StreamingContext(sc, batchDuration)
+    val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString)
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.1)
+      .setNumIterations(50)
+
+    model.trainOn(data)
+
+    ssc.start()
+
+    // write data to a file stream
+    for (i <- 0 until numBatches) {
+      val samples = LinearDataGenerator.generateLinearInput(
+        0.0, Array(10.0, 10.0), 100, 42 * (i + 1))
+      val file = new File(testDir, i.toString)
+      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
+      Thread.sleep(batchDuration.milliseconds)
+    }
+
+    ssc.stop(stopSparkContext=false)
+
+    System.clearProperty("spark.driver.port")
+    Utils.deleteRecursively(testDir)
+
+    // check accuracy of final parameter estimates
+    assertEqual(model.latestModel().intercept, 0.0, 0.1)
+    assertEqual(model.latestModel().weights(0), 10.0, 0.1)
+    assertEqual(model.latestModel().weights(1), 10.0, 0.1)
+
+    // check accuracy of predictions
+    val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
+    validatePrediction(validationData.map(row => model.latestModel().predict(row.features)),
+      validationData)
+  }
+
+  // Test that parameter estimates improve when learning Y = 10*X1 on streaming data
+  test("streaming linear regression parameter convergence") {
+
+    val testDir = Files.createTempDir()
+    val batchDuration = Milliseconds(2000)
+    val ssc = new StreamingContext(sc, batchDuration)
+    val numBatches = 5
+    val data = MLUtils.loadStreamingLabeledPoints(ssc, testDir.toString)
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0))
+      .setStepSize(0.1)
+      .setNumIterations(50)
+
+    model.trainOn(data)
+
+    ssc.start()
+
+    // write data to a file stream
+    val history = new ArrayBuffer[Double](numBatches)
+    for (i <- 0 until numBatches) {
+      val samples = LinearDataGenerator.generateLinearInput(0.0, Array(10.0), 100, 42 * (i + 1))
+      val file = new File(testDir, i.toString)
+      Files.write(samples.map(x => x.toString).mkString("\n"), file, Charset.forName("UTF-8"))
+      Thread.sleep(batchDuration.milliseconds)
+      // wait an extra few seconds to make sure the update finishes before new data arrive
+      Thread.sleep(4000)
+      history.append(math.abs(model.latestModel().weights(0) - 10.0))
+    }
+
+    ssc.stop(stopSparkContext=false)
+
+    System.clearProperty("spark.driver.port")
+    Utils.deleteRecursively(testDir)
+
+    val deltas = history.drop(1).zip(history.dropRight(1))
+    // check error stability (it always either shrinks, or increases with small tol)
+    assert(deltas.forall(x => (x._1 - x._2) <= 0.1))
+    // check that error shrunk on at least 2 batches
+    assert(deltas.map(x => if ((x._1 - x._2) < 0) 1 else 0).sum > 1)
+
+  }
+
+}
