chaoqin-li1123 commented on code in PR #45977:
URL: https://github.com/apache/spark/pull/45977#discussion_r1573100690


##########
python/pyspark/sql/datasource.py:
##########
@@ -469,6 +501,200 @@ def stop(self) -> None:
         ...
 
 
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry(InputPartition):
+    def __init__(self, start: dict, end: dict, it: Iterator[Tuple]):

Review Comment:
   Fixed.
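   (For illustration: `SimpleInputPartition` is just a carrier for a pair of offset dicts, and `PrefetchedCacheEntry` additionally holds the prefetched iterator. A minimal hypothetical construction, with made-up offset values:)

```python
# Hypothetical offsets; any (possibly nested) dict of primitive values works.
start = {"index": 0}
end = {"index": 10}

# The wrapper further down plans a single partition per microbatch carrying both offsets.
partition = SimpleInputPartition(start, end)
```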



##########
python/pyspark/sql/datasource.py:
##########
@@ -469,6 +501,200 @@ def stop(self) -> None:
         ...
 
 
+class SimpleInputPartition(InputPartition):
+    def __init__(self, start: dict, end: dict):
+        self.start = start
+        self.end = end
+
+
+class PrefetchedCacheEntry(InputPartition):
+    def __init__(self, start: dict, end: dict, it: Iterator[Tuple]):
+        self.start = start
+        self.end = end
+        self.it = it
+
+
+class SimpleDataSourceStreamReader(ABC):
+    """
+    A base class for simplified streaming data source readers.
+    Compared to :class:`DataSourceStreamReader`, :class:`SimpleDataSourceStreamReader` doesn't
+    require planning data partitions. Also, the read API of :class:`SimpleDataSourceStreamReader`
+    allows reading data and planning the latest offset at the same time.
+
+    .. versionadded:: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the checkpointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict or recursive dict whose keys and values are primitive types, which include
+            integer, string and boolean.
+
+        Examples
+        --------
+        >>> def initialOffset(self):
+        ...     return {"partition-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}}
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "initialOffset"},
+        )
+
+    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
+        """
+        Read all available data from the start offset and return the offset that the next
+        read attempt starts from.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        Returns
+        -------
+        A tuple of an iterator of :class:`Tuple`\\s and a dict
+            The iterator contains all the available records after the start offset.
+            The dict is the end offset of this read attempt and the start of the next
+            read attempt.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "read"},
+        )
+
+    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
+        """
+        Read all available data between a specific start offset and end offset.
+        This is invoked during failure recovery to re-read a batch deterministically
+        in order to achieve exactly-once semantics.
+
+        Parameters
+        ----------
+        start : dict
+            The start offset to start reading from.
+
+        end : dict
+            The offset where reading stops.
+
+        Returns
+        -------
+        iterator of :class:`Tuple`\\s
+            All the records between start offset and end offset.
+        """
+        raise PySparkNotImplementedError(
+            error_class="NOT_IMPLEMENTED",
+            message_parameters={"feature": "readBetweenOffsets"},
+        )
+
+    def commit(self, end: dict) -> None:
+        """
+        Informs the source that Spark has completed processing all data for offsets less than or
+        equal to `end` and will only request offsets greater than `end` in the future.
+
+        Parameters
+        ----------
+        end : dict
+            The latest offset that the streaming query has processed for this source.
+        """
+        ...
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+    """
+    A private class that wraps :class:`SimpleDataSourceStreamReader` in a prefetch-and-cache
+    pattern, so that :class:`SimpleDataSourceStreamReader` can integrate with the streaming
+    engine like an ordinary :class:`DataSourceStreamReader`.
+
+    current_offset tracks the latest progress of the record prefetching. It is initialized to
+    initialOffset() when the query starts for the first time, or to the end offset of the last
+    committed batch when the query restarts.
+
+    When the streaming engine calls latestOffset(), the wrapper calls read() starting from
+    current_offset, prefetches and caches the data, then updates current_offset to the end
+    offset of the new data.
+
+    When the streaming engine calls planInputPartitions(start, end), the wrapper gets the
+    prefetched data from the cache and sends it to the JVM along with the input partitions.
+
+    When the query restarts, batches in the write-ahead offset log that have not been committed
+    will be replayed by reading the data between the start and end offsets through
+    readBetweenOffsets(start, end).
+    """
+
+    def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+        self.simple_reader = simple_reader
+        self.initial_offset: Optional[dict] = None
+        self.current_offset: Optional[dict] = None
+        self.cache: List[PrefetchedCacheEntry] = []
+
+    def initialOffset(self) -> dict:
+        if self.initial_offset is None:
+            self.initial_offset = self.simple_reader.initialOffset()
+        return self.initial_offset
+
+    def latestOffset(self) -> dict:
+        # When the query starts for the first time, use the initial offset as the start offset.
+        if self.current_offset is None:
+            self.current_offset = self.initialOffset()
+        (it, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, it))
+        self.current_offset = end
+        return end
+
+    def commit(self, end: dict) -> None:
+        if self.current_offset is None:
+            self.current_offset = end
+
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if end_idx > 0:
+            # Drop prefetched data for batches that have been committed.
+            self.cache = self.cache[end_idx:]
+        self.simple_reader.commit(end)
+
+    def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+        # When the query restarts from a checkpoint, use the last committed offset as the start
+        # offset. This relies on the current behavior that the streaming engine calls getBatch
+        # on the last microbatch when the query restarts.
+        if self.current_offset is None:
+            self.current_offset = end
+        if len(self.cache) > 0:
+            assert self.cache[-1].end == end
+        return [SimpleInputPartition(start, end)]
+
+    def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+        start_idx = -1
+        end_idx = -1
+        for idx, entry in enumerate(self.cache):
+            # There is no convenient way to compare 2 offsets.
+            # Serialize into json string before comparison.
+            if json.dumps(entry.start) == json.dumps(start):
+                start_idx = idx
+            if json.dumps(entry.end) == json.dumps(end):
+                end_idx = idx
+                break
+        if start_idx == -1 or end_idx == -1:
+            return None  # type: ignore[return-value]
+        # Chain all the data iterators between the start offset and the end offset.
+        # Need to copy here to avoid exhausting the original data iterators.
+        entries = [copy.copy(entry.it) for entry in self.cache[start_idx : end_idx + 1]]
+        it = chain(*entries)
+        return it
+
+    def read(self, input_partition: InputPartition) -> Iterator[Tuple]:
+        return self.simple_reader.readBetweenOffsets(
+            input_partition.start, input_partition.end  # type: ignore[attr-defined]

Review Comment:
   Fixed.
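   For readers skimming this thread, here is a minimal, hypothetical sketch of what a `SimpleDataSourceStreamReader` subclass could look like under the API in this hunk. The `RangeStreamReader` name, the fixed batch size of 10, and the import path are assumptions for illustration only:

```python
from typing import Iterator, Tuple

# Assumes the class ships in pyspark.sql.datasource once this PR lands.
from pyspark.sql.datasource import SimpleDataSourceStreamReader


class RangeStreamReader(SimpleDataSourceStreamReader):
    """Hypothetical reader that emits consecutive integers, 10 rows per read() call."""

    def initialOffset(self) -> dict:
        # Offsets are plain dicts of primitive values.
        return {"index": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        # Return all "available" rows after `start` plus the offset the next read starts from.
        begin = start["index"]
        end = {"index": begin + 10}
        return iter([(i,) for i in range(begin, end["index"])]), end

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        # Deterministic replay of a batch between two offsets, used for failure recovery.
        return iter([(i,) for i in range(start["index"], end["index"])])

    def commit(self, end: dict) -> None:
        # Nothing to clean up for this in-memory example.
        pass
```

   With the wrapper above, `latestOffset()` maps to one `read()` call whose result is cached, `partitions()` returns a single `SimpleInputPartition(start, end)`, and the executor either consumes the cached iterator or falls back to `readBetweenOffsets()` on replay.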



##########
python/pyspark/sql/datasource.py:
##########
@@ -183,11 +186,40 @@ def streamWriter(self, schema: StructType, overwrite: bool) -> "DataSourceStream
             message_parameters={"feature": "streamWriter"},
         )
 
+    def _streamReader(self, schema: StructType) -> "DataSourceStreamReader":
+        try:
+            return self.streamReader(schema=schema)
+        except PySparkNotImplementedError:
+            return _SimpleStreamReaderWrapper(self.simpleStreamReader(schema=schema))
+
+    def simpleStreamReader(self, schema: StructType) -> "SimpleDataSourceStreamReader":
+        """
+        Returns a :class:`SimpleDataSourceStreamReader` instance for reading data.
+
+        One of simpleStreamReader() and streamReader() must be implemented for a readable
+        streaming data source.

Review Comment:
   Added doc to SimpleStreamReader noting that it is only intended for lightweight use cases where the input rate and batch size are small.
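   To illustrate the fallback in `_streamReader()` above, a data source can implement only `simpleStreamReader()`. The `MyStreamingSource` class below is hypothetical and reuses the `RangeStreamReader` sketched in the previous comment:

```python
from pyspark.sql.datasource import DataSource, SimpleDataSourceStreamReader
from pyspark.sql.types import StructType


class MyStreamingSource(DataSource):
    """Hypothetical data source that only implements the simple reader API."""

    @classmethod
    def name(cls) -> str:
        return "my_simple_source"

    def schema(self) -> str:
        return "value int"

    def simpleStreamReader(self, schema: StructType) -> SimpleDataSourceStreamReader:
        # streamReader() is left unimplemented, so _streamReader() catches the
        # PySparkNotImplementedError and wraps this reader in _SimpleStreamReaderWrapper,
        # letting the engine treat it as an ordinary DataSourceStreamReader.
        # RangeStreamReader is the hypothetical reader from the previous sketch.
        return RangeStreamReader()
```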


