Github user HeartSaVioR commented on a diff in the pull request:
https://github.com/apache/spark/pull/21477#discussion_r195625921
--- Diff: python/pyspark/sql/streaming.py ---
@@ -843,6 +843,170 @@ def trigger(self, processingTime=None, once=None,
continuous=None):
self._jwrite = self._jwrite.trigger(jTrigger)
return self
+ @since(2.4)
+ def foreach(self, f):
+ """
+ Sets the output of the streaming query to be processed using the
provided writer ``f``.
+ This is often used to write the output of a streaming query to
arbitrary storage systems.
+ The processing logic can be specified in two ways.
+
+ #. A **function** that takes a row as input.
+ This is a simple way to express your processing logic. Note
that this does
+ not allow you to deduplicate generated data when failures
cause reprocessing of
+ some input data. That would require you to specify the
processing logic in the next
+ way.
+
+ #. An **object** with a ``process`` method and optional ``open``
and ``close`` methods.
+ The object can have the following methods.
+
+ * ``open(partition_id, epoch_id)``: *Optional* method that
initializes the processing
+ (for example, open a connection, start a transaction,
etc). Additionally, you can
+ use the `partition_id` and `epoch_id` to deduplicate
regenerated data
+ (discussed later).
+
+ * ``process(row)``: *Non-optional* method that processes each
:class:`Row`.
+
+ * ``close(error)``: *Optional* method that finalizes and
cleans up (for example,
+ close connection, commit transaction, etc.) after all rows
have been processed.
+
+ The object will be used by Spark in the following way.
+
+ * A single copy of this object is responsible of all the data
generated by a
+ single task in a query. In other words, one instance is
responsible for
+ processing one partition of the data generated in a
distributed manner.
+
+ * This object must be serializable because each task will get
a fresh
+ serialized-deserialized copy of the provided object.
Hence, it is strongly
+ recommended that any initialization for writing data (e.g.
opening a
+ connection or starting a transaction) is done after the
`open(...)`
+ method has been called, which signifies that the task is
ready to generate data.
+
+ * The lifecycle of the methods are as follows.
+
+ For each partition with ``partition_id``:
+
+ ... For each batch/epoch of streaming data with
``epoch_id``:
+
+ ....... Method ``open(partitionId, epochId)`` is called.
+
+ ....... If ``open(...)`` returns true, for each row in the
partition and
+ batch/epoch, method ``process(row)`` is called.
+
+ ....... Method ``close(errorOrNull)`` is called with error
(if any) seen while
+ processing rows.
+
+ Important points to note:
+
+ * The `partitionId` and `epochId` can be used to deduplicate
generated data when
+ failures cause reprocessing of some input data. This
depends on the execution
+ mode of the query. If the streaming query is being
executed in the micro-batch
+ mode, then every partition represented by a unique tuple
(partition_id, epoch_id)
+ is guaranteed to have the same data. Hence, (partition_id,
epoch_id) can be used
+ to deduplicate and/or transactionally commit data and
achieve exactly-once
+ guarantees. However, if the streaming query is being
executed in the continuous
+ mode, then this guarantee does not hold and therefore
should not be used for
+ deduplication.
+
+ * The ``close()`` method (if exists) will be called if
`open()` method exists and
+ returns successfully (irrespective of the return value),
except if the Python
+ crashes in the middle.
+
+ .. note:: Evolving.
+
+ >>> # Print every row using a function
+ >>> def print_row(row):
+ ... print(row)
+ ...
+ >>> writer = sdf.writeStream.foreach(print_row)
+ >>> # Print every row using a object with process() method
+ >>> class RowPrinter:
+ ... def open(self, partition_id, epoch_id):
+ ... print("Opened %d, %d" % (partition_id, epoch_id))
+ ... return True
+ ... def process(self, row):
+ ... print(row)
+ ... def close(self, error):
+ ... print("Closed with error: %s" % str(error))
+ ...
+ >>> writer = sdf.writeStream.foreach(RowPrinter())
+ """
+
+ from pyspark.rdd import _wrap_function
+ from pyspark.serializers import PickleSerializer,
AutoBatchedSerializer
+ from pyspark.taskcontext import TaskContext
+
+ if callable(f):
+ # The provided object is a callable function that is supposed
to be called on each row.
+ # Construct a function that takes an iterator and calls the
provided function on each
+ # row.
+ def func_without_process(_, iterator):
+ for x in iterator:
+ f(x)
+ return iter([])
+
+ func = func_without_process
+
+ else:
+ # The provided object is not a callable function. Then it is
expected to have a
+ # 'process(row)' method, and optional 'open(partition_id,
epoch_id)' and
+ # 'close(error)' methods.
+
+ if not hasattr(f, 'process'):
+ raise Exception("Provided object does not have a 'process'
method")
+
+ if not callable(getattr(f, 'process')):
+ raise Exception("Attribute 'process' in provided object is
not callable")
+
+ def doesMethodExist(method_name):
+ exists = hasattr(f, method_name)
+ if exists and not callable(getattr(f, method_name)):
+ raise Exception(
+ "Attribute '%s' in provided object is not
callable" % method_name)
+ return exists
+
+ open_exists = doesMethodExist('open')
+ close_exists = doesMethodExist('close')
+
+ def func_with_open_process_close(partition_id, iterator):
+ epoch_id =
TaskContext.get().getLocalProperty('streaming.sql.batchId')
+ if epoch_id:
+ epoch_id = int(epoch_id)
+ else:
+ raise Exception("Could not get batch id from
TaskContext")
+
+ # Check if the data should be processed
+ should_process = True
+ if open_exists:
+ should_process = f.open(partition_id, epoch_id)
+
+ error = None
+
+ try:
+ if should_process:
+ for x in iterator:
+ f.process(x)
+
--- End diff --
This looks like it is strictly a matter of preference, and I prefer the newline. Are
you referring to the Spark style guide, PEP 8, or something similar?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]