[SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark
Adding the ability to define an initial state RDD for use with updateStateByKey in PySpark. Added a unit test and changed the stateful_network_wordcount example to use an initial RDD. Author: Bryan Cutler <[email protected]> Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713. Project: http://git-wip-us.apache.org/repos/asf/bahir/repo Commit: http://git-wip-us.apache.org/repos/asf/bahir/commit/e30f0c2d Tree: http://git-wip-us.apache.org/repos/asf/bahir/tree/e30f0c2d Diff: http://git-wip-us.apache.org/repos/asf/bahir/diff/e30f0c2d Branch: refs/heads/master Commit: e30f0c2d3893814ef43a6d81a0e39c1910ce1f2b Parents: be8b358 Author: Bryan Cutler <[email protected]> Authored: Thu Dec 10 14:21:15 2015 -0800 Committer: Davies Liu <[email protected]> Committed: Thu Dec 10 14:21:15 2015 -0800 ---------------------------------------------------------------------- streaming-mqtt/python/dstream.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/bahir/blob/e30f0c2d/streaming-mqtt/python/dstream.py ---------------------------------------------------------------------- diff --git a/streaming-mqtt/python/dstream.py b/streaming-mqtt/python/dstream.py index acec850..f61137c 100644 --- a/streaming-mqtt/python/dstream.py +++ b/streaming-mqtt/python/dstream.py @@ -568,7 +568,7 @@ class DStream(object): self._ssc._jduration(slideDuration)) return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) - def updateStateByKey(self, updateFunc, numPartitions=None): + def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None): """ Return a new "state" DStream where the state for each key is updated by applying the given function on the previous state of the key and the new values of the key. 
@@ -579,6 +579,9 @@ class DStream(object): if numPartitions is None: numPartitions = self._sc.defaultParallelism + if initialRDD and not isinstance(initialRDD, RDD): + initialRDD = self._sc.parallelize(initialRDD) + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) @@ -590,7 +593,13 @@ class DStream(object): jreduceFunc = TransformFunction(self._sc, reduceFunc, self._sc.serializer, self._jrdd_deserializer) - dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + if initialRDD: + initialRDD = initialRDD._reserialize(self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc, + initialRDD._jrdd) + else: + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
