Repository: spark
Updated Branches:
  refs/heads/master 4a46b8859 -> 6a6c1fc5c


[SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark

Added the ability to define an initial state RDD for use with updateStateByKey 
in PySpark. Added a unit test and changed the stateful_network_wordcount example 
to use an initial RDD.

Author: Bryan Cutler <bjcut...@us.ibm.com>

Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a6c1fc5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a6c1fc5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a6c1fc5

Branch: refs/heads/master
Commit: 6a6c1fc5c807ba4e8aba3e260537aa527ff5d46a
Parents: 4a46b88
Author: Bryan Cutler <bjcut...@us.ibm.com>
Authored: Thu Dec 10 14:21:15 2015 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Dec 10 14:21:15 2015 -0800

----------------------------------------------------------------------
 .../streaming/stateful_network_wordcount.py     |  5 ++++-
 python/pyspark/streaming/dstream.py             | 13 +++++++++++--
 python/pyspark/streaming/tests.py               | 20 ++++++++++++++++++++
 .../streaming/api/python/PythonDStream.scala    | 14 ++++++++++++--
 4 files changed, 47 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a6c1fc5/examples/src/main/python/streaming/stateful_network_wordcount.py
----------------------------------------------------------------------
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py 
b/examples/src/main/python/streaming/stateful_network_wordcount.py
index 16ef646..f8bbc65 100644
--- a/examples/src/main/python/streaming/stateful_network_wordcount.py
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -44,13 +44,16 @@ if __name__ == "__main__":
     ssc = StreamingContext(sc, 1)
     ssc.checkpoint("checkpoint")
 
+    # RDD with initial state (key, value) pairs
+    initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)])
+
     def updateFunc(new_values, last_sum):
         return sum(new_values) + (last_sum or 0)
 
     lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
     running_counts = lines.flatMap(lambda line: line.split(" "))\
                           .map(lambda word: (word, 1))\
-                          .updateStateByKey(updateFunc)
+                          .updateStateByKey(updateFunc, 
initialRDD=initialStateRDD)
 
     running_counts.pprint()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a6c1fc5/python/pyspark/streaming/dstream.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/dstream.py 
b/python/pyspark/streaming/dstream.py
index acec850..f61137c 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -568,7 +568,7 @@ class DStream(object):
                                                              
self._ssc._jduration(slideDuration))
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
 
-    def updateStateByKey(self, updateFunc, numPartitions=None):
+    def updateStateByKey(self, updateFunc, numPartitions=None, 
initialRDD=None):
         """
         Return a new "state" DStream where the state for each key is updated 
by applying
         the given function on the previous state of the key and the new values 
of the key.
@@ -579,6 +579,9 @@ class DStream(object):
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
 
+        if initialRDD and not isinstance(initialRDD, RDD):
+            initialRDD = self._sc.parallelize(initialRDD)
+
         def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).mapValues(lambda vs: 
(list(vs), None))
@@ -590,7 +593,13 @@ class DStream(object):
 
         jreduceFunc = TransformFunction(self._sc, reduceFunc,
                                         self._sc.serializer, 
self._jrdd_deserializer)
-        dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), 
jreduceFunc)
+        if initialRDD:
+            initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+            dstream = 
self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc,
+                                                       initialRDD._jrdd)
+        else:
+            dstream = 
self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a6c1fc5/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py 
b/python/pyspark/streaming/tests.py
index a2bfd79..4949cd6 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -403,6 +403,26 @@ class BasicOperationTests(PySparkStreamingTestCase):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)
 
+    def test_update_state_by_key_initial_rdd(self):
+
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
+
+        initial = [('k', [0, 1])]
+        initial = self.sc.parallelize(initial, 1)
+
+        input = [[('k', i)] for i in range(2, 5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater, initialRDD=initial)
+
+        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
     def test_failed_func(self):
         # Test failure in
         # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)

http://git-wip-us.apache.org/repos/asf/spark/blob/6a6c1fc5/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 994309d..056248c 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream(
  */
 private[python] class PythonStateDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonTransformFunction)
+    reduceFunc: PythonTransformFunction,
+    initialRDD: Option[RDD[Array[Byte]]])
   extends PythonDStream(parent, reduceFunc) {
 
+  def this(
+    parent: DStream[Array[Byte]],
+    reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None)
+
+  def this(
+    parent: DStream[Array[Byte]],
+    reduceFunc: PythonTransformFunction,
+    initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, 
Some(initialRDD.rdd))
+
   super.persist(StorageLevel.MEMORY_ONLY)
   override val mustCheckpoint = true
 
@@ -274,7 +284,7 @@ private[python] class PythonStateDStream(
     val lastState = getOrCompute(validTime - slideDuration)
     val rdd = parent.getOrCompute(validTime)
     if (rdd.isDefined) {
-      func(lastState, rdd, validTime)
+      func(lastState.orElse(initialRDD), rdd, validTime)
     } else {
       lastState
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to