Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/21477#discussion_r195623149
--- Diff: python/pyspark/sql/tests.py ---
@@ -1885,6 +1885,263 @@ def test_query_manager_await_termination(self):
q.stop()
shutil.rmtree(tmpPath)
+ class ForeachWriterTester:
+
+ def __init__(self, spark):
+ self.spark = spark
+
+ def write_open_event(self, partitionId, epochId):
+ self._write_event(
+ self.open_events_dir,
+ {'partition': partitionId, 'epoch': epochId})
+
+ def write_process_event(self, row):
+ self._write_event(self.process_events_dir, {'value': 'text'})
+
+ def write_close_event(self, error):
+ self._write_event(self.close_events_dir, {'error': str(error)})
+
+ def write_input_file(self):
+ self._write_event(self.input_dir, "text")
+
+ def open_events(self):
+ return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+ def process_events(self):
+ return self._read_events(self.process_events_dir, 'value STRING')
+
+ def close_events(self):
+ return self._read_events(self.close_events_dir, 'error STRING')
+
+ def run_streaming_query_on_writer(self, writer, num_files):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ for i in range(num_files):
+ self.write_input_file()
+ sq.processAllAvailable()
+ finally:
+ self.stop_all()
+
+ def assert_invalid_writer(self, writer, msg=None):
+ self._reset()
+ try:
+ sdf = self.spark.readStream.format('text').load(self.input_dir)
+ sq = sdf.writeStream.foreach(writer).start()
+ self.write_input_file()
+ sq.processAllAvailable()
+ self.fail("invalid writer %s did not fail the query" % str(writer)) # not expected
+ except Exception as e:
+ if msg:
+ assert(msg in str(e), "%s not in %s" % (msg, str(e)))
+
+ finally:
+ self.stop_all()
+
+ def stop_all(self):
+ for q in self.spark._wrapped.streams.active:
+ q.stop()
+
+ def _reset(self):
+ self.input_dir = tempfile.mkdtemp()
+ self.open_events_dir = tempfile.mkdtemp()
+ self.process_events_dir = tempfile.mkdtemp()
+ self.close_events_dir = tempfile.mkdtemp()
+
+ def _read_events(self, dir, json):
+ rows = self.spark.read.schema(json).json(dir).collect()
+ dicts = [row.asDict() for row in rows]
+ return dicts
+
+ def _write_event(self, dir, event):
+ import uuid
+ with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
+ f.write("%s\n" % str(event))
+
+ def __getstate__(self):
+ return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+ def __setstate__(self, state):
+ self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+ def test_streaming_foreach_with_simple_function(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ def foreach_func(row):
+ tester.write_process_event(row)
+
+ tester.run_streaming_query_on_writer(foreach_func, 2)
+ self.assertEqual(len(tester.process_events()), 2)
+
+ def test_streaming_foreach_with_basic_open_process_close(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partitionId, epochId):
+ tester.write_open_event(partitionId, epochId)
+ return True
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ open_events = tester.open_events()
+ self.assertEqual(len(open_events), 2)
+ self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+ self.assertEqual(len(tester.process_events()), 2)
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_with_open_returning_false(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def open(self, partition_id, epoch_id):
+ tester.write_open_event(partition_id, epoch_id)
+ return False
+
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+ self.assertEqual(len(tester.open_events()), 2)
+
+ self.assertEqual(len(tester.process_events()), 0) # no row was processed
+
+ close_events = tester.close_events()
+ self.assertEqual(len(close_events), 2)
+ self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+ def test_streaming_foreach_without_open_method(self):
+ tester = self.ForeachWriterTester(self.spark)
+
+ class ForeachWriter:
+ def process(self, row):
+ tester.write_process_event(row)
+
+ def close(self, error):
+ tester.write_close_event(error)
+
+ tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+ self.assertEqual(len(tester.open_events()), 0) # no open events
--- End diff --
Ditto: use two spaces before the inline comment (PEP 8).
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]