Github user HeartSaVioR commented on a diff in the pull request:
https://github.com/apache/spark/pull/21477#discussion_r193289099
--- Diff: python/pyspark/sql/tests.py ---
@@ -1884,7 +1885,164 @@ def test_query_manager_await_termination(self):
         finally:
             q.stop()
             shutil.rmtree(tmpPath)
+    '''
+    class ForeachWriterTester:
+
+        def __init__(self, spark):
+            self.spark = spark
+            self.input_dir = tempfile.mkdtemp()
+            self.open_events_dir = tempfile.mkdtemp()
+            self.process_events_dir = tempfile.mkdtemp()
+            self.close_events_dir = tempfile.mkdtemp()
+
+        def write_open_event(self, partitionId, epochId):
+            self._write_event(
+                self.open_events_dir,
+                {'partition': partitionId, 'epoch': epochId})
+
+        def write_process_event(self, row):
+            self._write_event(self.process_events_dir, {'value': 'text'})
+
+        def write_close_event(self, error):
+            self._write_event(self.close_events_dir, {'error': str(error)})
+
+        def write_input_file(self):
+            self._write_event(self.input_dir, "text")
+
+        def open_events(self):
+            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+        def process_events(self):
+            return self._read_events(self.process_events_dir, 'value STRING')
+
+        def close_events(self):
+            return self._read_events(self.close_events_dir, 'error STRING')
+
+        def run_streaming_query_on_writer(self, writer, num_files):
+            try:
+                sdf = self.spark.readStream.format('text').load(self.input_dir)
+                sq = sdf.writeStream.foreach(writer).start()
+                for i in range(num_files):
+                    self.write_input_file()
+                    sq.processAllAvailable()
+                sq.stop()
+            finally:
+                self.stop_all()
+
+        def _read_events(self, dir, json):
+            rows = self.spark.read.schema(json).json(dir).collect()
+            dicts = [row.asDict() for row in rows]
+            return dicts
+
+        def _write_event(self, dir, event):
+            import random
+            file = open(os.path.join(dir, str(random.randint(0, 100000))), 'w')
+            file.write("%s\n" % str(event))
+            file.close()
+
+        def stop_all(self):
+            for q in self.spark._wrapped.streams.active:
+                q.stop()
+
+        def __getstate__(self):
+            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+        def __setstate__(self, state):
+            self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+    def test_streaming_foreach_with_simple_function(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        def foreach_func(row):
+            tester.write_process_event(row)
+
+        tester.run_streaming_query_on_writer(foreach_func, 2)
+        self.assertEqual(len(tester.process_events()), 2)
+
+    def test_streaming_foreach_with_basic_open_process_close(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return True
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        open_events = tester.open_events()
+        self.assertEqual(len(open_events), 2)
+        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+        self.assertEqual(len(tester.process_events()), 2)
+
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_with_open_returning_false(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return False
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        self.assertEqual(len(tester.open_events()), 2)
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_with_process_throwing_error(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return True
+
+            def process(self, row):
+                raise Exception("test error")
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        try:
+            sdf = self.spark.readStream.format('text').load(tester.input_dir)
+            sq = sdf.writeStream.foreach(ForeachWriter()).start()
+            tester.write_input_file()
+            sq.processAllAvailable()
+            self.fail("bad writer should fail the query")  # this is not expected
+        except StreamingQueryException as e:
+            # self.assertTrue("test error" in e.desc)  # this is expected
+            pass
+        finally:
+            tester.stop_all()
+
+        self.assertEqual(len(tester.open_events()), 1)
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 1)
+        # self.assertTrue("test error" in e[0]['error'])
+
+    '''
--- End diff ---
Leaving a marker here as well: `This is only to speed up local testing. Will remove this`.
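
For reference, here is a minimal sketch of the two writer forms these tests exercise, assuming the `foreach()` semantics shown in the diff above; the session setup and the input path `/tmp/foreach-demo-input` are illustrative only:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[2]').getOrCreate()
# Hypothetical input directory; the tests above use tempfile.mkdtemp() instead.
sdf = spark.readStream.format('text').load('/tmp/foreach-demo-input')

# Form 1: a plain function, invoked once per row
# (test_streaming_foreach_with_simple_function).
def print_row(row):
    print(row.value)

# Form 2: an object with open/process/close. process() runs only for
# partitions where open() returned True; close() runs either way
# (the basic, open-returning-false, and error-throwing tests).
class PrintWriter:
    def open(self, partition_id, epoch_id):
        return True

    def process(self, row):
        print(row.value)

    def close(self, error):
        pass

query = sdf.writeStream.foreach(PrintWriter()).start()  # or .foreach(print_row)
query.processAllAvailable()
query.stop()
```

Note that `ForeachWriterTester` pickles only its event-directory paths via `__getstate__`/`__setstate__`, so the writer closures can be shipped to executors without dragging the driver-side `SparkSession` along.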