Repository: spark
Updated Branches:
  refs/heads/branch-1.6 5647774b0 -> 012de2ce5


[SPARK-12002][STREAMING][PYSPARK] Fix python direct stream checkpoint recovery issue

Fixed a minor race condition in #10017

Closes #10017

Author: jerryshao <ss...@hortonworks.com>
Author: Shixiong Zhu <shixi...@databricks.com>

Closes #10074 from zsxwing/review-pr10017.

(cherry picked from commit f292018f8e57779debc04998456ec875f628133b)
Signed-off-by: Shixiong Zhu <shixi...@databricks.com>
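
Context, recoverable from the util.py hunk below: TransformFunctionSerializer.dumps
used to pickle only (func, deserializers) and dropped rdd_wrap_func, the hook that
KafkaUtils installs via TransformFunction.rdd_wrapper() so that transform functions
receive a KafkaRDD rather than a plain RDD. A stream restored from a checkpoint
therefore wrapped its batches as plain RDDs, and rdd.offsetRanges() failed inside
the transform. A minimal sketch of the user-facing pattern this commit makes
recoverable; it assumes Spark 1.6-era PySpark and a reachable Kafka broker, and the
topic name, broker address, and checkpoint path are illustrative:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils

    CHECKPOINT = "/tmp/direct-kafka-checkpoint"  # illustrative path

    def create_context():
        sc = SparkContext(appName="direct-kafka-recovery-sketch")
        ssc = StreamingContext(sc, 2)  # 2-second batches
        ssc.checkpoint(CHECKPOINT)

        stream = KafkaUtils.createDirectStream(
            ssc, ["my-topic"], {"metadata.broker.list": "localhost:9092"})

        def track_offsets(rdd):
            # Before this fix, rdd arrived as a plain RDD after checkpoint
            # recovery and offsetRanges() was unavailable here.
            for o in rdd.offsetRanges():
                print(o.topic, o.partition, o.fromOffset, o.untilOffset)
            return rdd

        stream.transform(track_offsets).count().pprint()
        return ssc

    # On restart, getOrCreate recovers the streaming graph, including the
    # transform and (after this fix) its KafkaRDD wrapper, from the checkpoint.
    ssc = StreamingContext.getOrCreate(CHECKPOINT, create_context)
    ssc.start()
    ssc.awaitTermination()

The second StreamingContext.getOrCreate call in the new test below exercises
exactly this restart path.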


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/012de2ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/012de2ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/012de2ce

Branch: refs/heads/branch-1.6
Commit: 012de2ce5de01bc57197fa26334fc175c8f20233
Parents: 5647774
Author: jerryshao <ss...@hortonworks.com>
Authored: Tue Dec 1 15:26:10 2015 -0800
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Tue Dec 1 15:26:20 2015 -0800

----------------------------------------------------------------------
 python/pyspark/streaming/tests.py | 49 ++++++++++++++++++++++++++++++++++
 python/pyspark/streaming/util.py  | 13 ++++-----
 2 files changed, 56 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/012de2ce/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index d380d69..a2bfd79 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -1150,6 +1150,55 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         self.assertNotEqual(topic_and_partition_a, topic_and_partition_d)
 
     @unittest.skipIf(sys.version >= "3", "long type not supported")
+    def test_kafka_direct_stream_transform_with_checkpoint(self):
+        """Test the Python direct Kafka stream transform with checkpoint 
correctly recovered."""
+        topic = self._randomTopic()
+        sendData = {"a": 1, "b": 2, "c": 3}
+        kafkaParams = {"metadata.broker.list": self._kafkaTestUtils.brokerAddress(),
+                       "auto.offset.reset": "smallest"}
+
+        self._kafkaTestUtils.createTopic(topic)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
+
+        offsetRanges = []
+
+        def transformWithOffsetRanges(rdd):
+            for o in rdd.offsetRanges():
+                offsetRanges.append(o)
+            return rdd
+
+        self.ssc.stop(False)
+        self.ssc = None
+        tmpdir = "checkpoint-test-%d" % random.randint(0, 10000)
+
+        def setup():
+            ssc = StreamingContext(self.sc, 0.5)
+            ssc.checkpoint(tmpdir)
+            stream = KafkaUtils.createDirectStream(ssc, [topic], kafkaParams)
+            stream.transform(transformWithOffsetRanges).count().pprint()
+            return ssc
+
+        try:
+            ssc1 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc1.start()
+            self.wait_for(offsetRanges, 1)
+            self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))])
+
+            # To make sure some checkpoint is written
+            time.sleep(3)
+            ssc1.stop(False)
+            ssc1 = None
+
+            # Restart again to make sure the checkpoint is recovered correctly
+            ssc2 = StreamingContext.getOrCreate(tmpdir, setup)
+            ssc2.start()
+            ssc2.awaitTermination(3)
+            ssc2.stop(stopSparkContext=False, stopGraceFully=True)
+            ssc2 = None
+        finally:
+            shutil.rmtree(tmpdir)
+
+    @unittest.skipIf(sys.version >= "3", "long type not supported")
     def test_kafka_rdd_message_handler(self):
         """Test Python direct Kafka RDD MessageHandler."""
         topic = self._randomTopic()

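The util.py change below is where the recovery bug lived. Its mechanics can be
reproduced with the stdlib pickle and a simplified stand-in class, no Spark
required; Func, plain_wrapper, and kafka_wrapper are illustrative names, not
Spark APIs:

    import pickle

    def plain_wrapper(jrdd, ctx, ser):
        return ("RDD", jrdd)       # default: plain RDD, no offsetRanges()

    def kafka_wrapper(jrdd, ctx, ser):
        return ("KafkaRDD", jrdd)  # special RDD exposing offsetRanges()

    def transform(time, rdd):
        return rdd

    class Func(object):
        """Simplified stand-in for TransformFunction."""
        def __init__(self, func, wrap_func=plain_wrapper):
            self.func = func
            self.rdd_wrap_func = wrap_func

    original = Func(transform, kafka_wrapper)

    # Old behaviour: only the function round-tripped through the checkpoint,
    # so recovery fell back to the default wrapper and the KafkaRDD type was
    # silently forgotten.
    recovered = Func(pickle.loads(pickle.dumps(original.func)))
    assert recovered.rdd_wrap_func is plain_wrapper

    # Fixed behaviour: the wrapper travels with the function, mirroring the
    # new dumps/loads below.
    f, wrap = pickle.loads(pickle.dumps((original.func, original.rdd_wrap_func)))
    recovered = Func(f, wrap)
    assert recovered.rdd_wrap_func is kafka_wrapper
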
http://git-wip-us.apache.org/repos/asf/spark/blob/012de2ce/python/pyspark/streaming/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index c7f02bc..abbbf6e 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -37,11 +37,11 @@ class TransformFunction(object):
         self.ctx = ctx
         self.func = func
         self.deserializers = deserializers
-        self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.rdd_wrap_func = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
         self.failure = None
 
     def rdd_wrapper(self, func):
-        self._rdd_wrapper = func
+        self.rdd_wrap_func = func
         return self
 
     def call(self, milliseconds, jrdds):
@@ -59,7 +59,7 @@ class TransformFunction(object):
             if len(sers) < len(jrdds):
                 sers += (sers[0],) * (len(jrdds) - len(sers))
 
-            rdds = [self._rdd_wrapper(jrdd, self.ctx, ser) if jrdd else None
+            rdds = [self.rdd_wrap_func(jrdd, self.ctx, ser) if jrdd else None
                     for jrdd, ser in zip(jrdds, sers)]
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
@@ -101,7 +101,8 @@ class TransformFunctionSerializer(object):
         self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
-            return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+            return bytearray(self.serializer.dumps((
+                func.func, func.rdd_wrap_func, func.deserializers)))
         except:
             self.failure = traceback.format_exc()
 
@@ -109,8 +110,8 @@ class TransformFunctionSerializer(object):
         # Clear the failure
         self.failure = None
         try:
-            f, deserializers = self.serializer.loads(bytes(data))
-            return TransformFunction(self.ctx, f, *deserializers)
+            f, wrap_func, deserializers = self.serializer.loads(bytes(data))
+            return TransformFunction(self.ctx, f, *deserializers).rdd_wrapper(wrap_func)
         except:
             self.failure = traceback.format_exc()
 

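One detail worth noting: the default rdd_wrap_func is a lambda, which the stdlib
pickle in the sketch above could not serialize; it round-trips here because
StreamingContext wires TransformFunctionSerializer up with PySpark's
CloudPickleSerializer, which can pickle lambdas. loads also restores the wrapper
through the existing rdd_wrapper() setter, so callers such as KafkaUtils need no
changes.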
