Repository: spark
Updated Branches:
  refs/heads/branch-2.1 0af94e772 -> 5f7a9af66


[SPARK-13027][STREAMING] Added batch time as a parameter to updateStateByKey

Added RDD batch time as an input parameter to the update function in 
updateStateByKey.

Author: Aaditya Ramesh <[email protected]>

Closes #11122 from aramesh117/SPARK-13027.

(cherry picked from commit 6f9e598ccf92f6272bbfb56ac56d3101387131b9)
Signed-off-by: Shixiong Zhu <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f7a9af6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f7a9af6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f7a9af6

Branch: refs/heads/branch-2.1
Commit: 5f7a9af66c0c05225f175f36bc10016874fab6fc
Parents: 0af94e7
Author: Aaditya Ramesh <[email protected]>
Authored: Tue Nov 15 13:01:01 2016 -0800
Committer: Shixiong Zhu <[email protected]>
Committed: Tue Nov 15 13:01:08 2016 -0800

----------------------------------------------------------------------
 .../dstream/PairDStreamFunctions.scala          | 40 +++++++++---
 .../spark/streaming/dstream/StateDStream.scala  | 28 +++++----
 .../spark/streaming/BasicOperationsSuite.scala  | 66 ++++++++++++++++++++
 .../spark/streaming/DStreamClosureSuite.scala   | 12 ++++
 4 files changed, 126 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5f7a9af6/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2f2a6d1..ac73941 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -453,9 +453,12 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
   def updateStateByKey[S: ClassTag](
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
-      rememberPartitioner: Boolean
-    ): DStream[(K, S)] = ssc.withScope {
-     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, 
rememberPartitioner, None)
+      rememberPartitioner: Boolean): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, 
None)
   }
 
   /**
@@ -499,10 +502,33 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
       rememberPartitioner: Boolean,
-      initialRDD: RDD[(K, S)]
-    ): DStream[(K, S)] = ssc.withScope {
-     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
-       rememberPartitioner, Some(initialRDD))
+      initialRDD: RDD[(K, S)]): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, 
Some(initialRDD))
+  }
+
+  /**
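+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * The update function additionally receives the time of the batch being processed.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. It is called with the batch time, the key,
+   *                   the new values for the key and the previous state of the key. If
+   *                   this function returns None, then the corresponding state key-value
+   *                   pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @param rememberPartitioner Passed to StateDStream as its preservePartitioning flag,
+   *                            which it forwards to mapPartitions on the state RDDs.
+   * @param initialRDD Optional initial state RDD; defaults to None.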
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](updateFunc: (Time, K, Seq[V], Option[S]) 
=> Option[S],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean,
+      initialRDD: Option[RDD[(K, S)]] = None): DStream[(K, S)] = ssc.withScope 
{
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (time: Time, iterator: Iterator[(K, Seq[V], 
Option[S])]) => {
+      iterator.flatMap(t => cleanedFunc(time, t._1, t._2, t._3).map(s => 
(t._1, s)))
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, 
initialRDD)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/5f7a9af6/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 8efb09a..5bf1dab 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Duration, Time}
 private[streaming]
 class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
     parent: DStream[(K, V)],
-    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+    updateFunc: (Time, Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
     partitioner: Partitioner,
     preservePartitioning: Boolean,
     initialRDD: Option[RDD[(K, S)]]
@@ -41,8 +41,10 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
 
   override val mustCheckpoint = true
 
-  private [this] def computeUsingPreviousRDD (
-    parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+  private [this] def computeUsingPreviousRDD(
+      batchTime: Time,
+      parentRDD: RDD[(K, V)],
+      prevStateRDD: RDD[(K, S)]) = {
     // Define the function for the mapPartition operation on cogrouped RDD;
     // first map the cogrouped tuple to tuples of required type,
     // and then apply the update function
@@ -53,7 +55,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         val headOption = if (itr.hasNext) Some(itr.next()) else None
         (t._1, t._2._1.toSeq, headOption)
       }
-      updateFuncLocal(i)
+      updateFuncLocal(batchTime, i)
     }
     val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
     val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
@@ -68,15 +70,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
       case Some(prevStateRDD) =>    // If previous state RDD exists
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
-          case Some(parentRDD) =>   // If parent RDD exists, then compute as 
usual
-            computeUsingPreviousRDD(parentRDD, prevStateRDD)
-          case None =>    // If parent RDD does not exist
-
+          case Some(parentRDD) =>    // If parent RDD exists, then compute as 
usual
+            computeUsingPreviousRDD (validTime, parentRDD, prevStateRDD)
+          case None =>     // If parent RDD does not exist
             // Re-apply the update function to the old state RDD
             val updateFuncLocal = updateFunc
             val finalFunc = (iterator: Iterator[(K, S)]) => {
               val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
-              updateFuncLocal(i)
+              updateFuncLocal(validTime, i)
             }
             val stateRDD = prevStateRDD.mapPartitions(finalFunc, 
preservePartitioning)
             Some(stateRDD)
@@ -93,15 +94,16 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
                 // and then apply the update function
                 val updateFuncLocal = updateFunc
                 val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-                  updateFuncLocal(iterator.map(tuple => (tuple._1, 
tuple._2.toSeq, None)))
+                  updateFuncLocal (validTime,
+                    iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
                 }
 
                 val groupedRDD = parentRDD.groupByKey(partitioner)
                 val sessionRDD = groupedRDD.mapPartitions(finalFunc, 
preservePartitioning)
                 // logDebug("Generating state RDD for time " + validTime + " 
(first)")
-                Some(sessionRDD)
-              case Some(initialStateRDD) =>
-                computeUsingPreviousRDD(parentRDD, initialStateRDD)
+                Some (sessionRDD)
+              case Some (initialStateRDD) =>
+                computeUsingPreviousRDD(validTime, parentRDD, initialStateRDD)
             }
           case None => // If parent RDD does not exist, then nothing to do!
             // logDebug("Not generating state RDD (no previous state, no 
parent)")

http://git-wip-us.apache.org/repos/asf/spark/blob/5f7a9af6/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
 
b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index cfcbdc7..4e702bb 100644
--- 
a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ 
b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -471,6 +471,72 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(inputData, updateStateOperation, outputData, true)
   }
 
+  test("updateStateByKey - testing time stamps as input") {
+    type StreamingState = Long
+    val initial: Seq[(String, StreamingState)] = Seq(("a", 0L), ("c", 0L))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    // a -> 1000, 3000, 6000, 10000, 15000, 15000
+    // b -> (absent), 2000, 5000, 9000, 9000, 9000
+    // c -> 0, 0, 3000, 3000, 3000, 3000
+
+    val outputData: Seq[Seq[(String, StreamingState)]] = Seq(
+        Seq(
+          ("a", 1000L),
+          ("c", 0L)), // t = 1000
+        Seq(
+          ("a", 3000L),
+          ("b", 2000L),
+          ("c", 0L)), // t = 2000
+        Seq(
+          ("a", 6000L),
+          ("b", 5000L),
+          ("c", 3000L)), // t = 3000
+        Seq(
+          ("a", 10000L),
+          ("b", 9000L),
+          ("c", 3000L)), // t = 4000
+        Seq(
+          ("a", 15000L),
+          ("b", 9000L),
+          ("c", 3000L)), // t = 5000
+        Seq(
+          ("a", 15000L),
+          ("b", 9000L),
+          ("c", 3000L)) // t = 6000
+      )
+
+    val updateStateOperation = (s: DStream[String]) => {
+      val initialRDD = s.context.sparkContext.makeRDD(initial)
+      val updateFunc = (time: Time,
+                        key: String,
+                        values: Seq[Int],
+                        state: Option[StreamingState]) => {
+        // Update only if we receive values for this key during the batch.
+        if (values.nonEmpty) {
+          Option(time.milliseconds + state.getOrElse(0L))
+        } else {
+          Option(state.getOrElse(0L))
+        }
+      }
+      s.map(x => (x, 1)).updateStateByKey[StreamingState](updateFunc = 
updateFunc,
+        partitioner = new HashPartitioner (numInputPartitions), 
rememberPartitioner = false,
+        initialRDD = Option(initialRDD))
+    }
+
+    testOperation(input = inputData, operation = updateStateOperation,
+      expectedOutput = outputData, useSet = true)
+  }
+
   test("updateStateByKey - with initial value RDD") {
     val initial = Seq(("a", 1), ("c", 2))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5f7a9af6/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala 
b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 1fc34f5..2ab600a 100644
--- 
a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ 
b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -164,6 +164,10 @@ class DStreamClosureSuite extends SparkFunSuite with 
BeforeAndAfterAll {
   private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
     val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
     val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; 
Seq((1, 1)).toIterator }
+    val updateF3 = (_: Time, _: Int, _: Seq[Int], _: Option[Int]) => {
+      return
+      Option(1)
+    }
     val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) }
     expectCorrectException { ds.updateStateByKey(updateF1) }
     expectCorrectException { ds.updateStateByKey(updateF1, 5) }
@@ -177,6 +181,14 @@ class DStreamClosureSuite extends SparkFunSuite with 
BeforeAndAfterAll {
     expectCorrectException {
       ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD)
     }
+    expectCorrectException {
+      ds.updateStateByKey(
+        updateFunc = updateF3,
+        partitioner = new HashPartitioner(5),
+        rememberPartitioner = true,
+        initialRDD = Option(initialRDD)
+      )
+    }
   }
   private def testMapValues(ds: DStream[(Int, Int)]): Unit = 
expectCorrectException {
     ds.mapValues { _ => return; 1 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to