HeartSaVioR commented on code in PR #45977:
URL: https://github.com/apache/spark/pull/45977#discussion_r1582734755
##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -60,14 +68,29 @@ def latest_offset_func(reader: DataSourceStreamReader,
outfile: IO) -> None:
write_with_length(json.dumps(offset).encode("utf-8"), outfile)
-def partitions_func(reader: DataSourceStreamReader, infile: IO, outfile: IO)
-> None:
+def partitions_func(
+ reader: DataSourceStreamReader,
+ data_source: DataSource,
+ schema: StructType,
+ max_arrow_batch_size: int,
+ infile: IO,
+ outfile: IO,
+) -> None:
start_offset = json.loads(utf8_deserializer.loads(infile))
end_offset = json.loads(utf8_deserializer.loads(infile))
partitions = reader.partitions(start_offset, end_offset)
# Return the serialized partition values.
write_int(len(partitions), outfile)
for partition in partitions:
pickleSer._write_with_length(partition, outfile)
+ if isinstance(reader, _SimpleStreamReaderWrapper):
+ it = reader.getCache(start_offset, end_offset)
+ if it is None:
Review Comment:
As I mentioned above, we could always send the batch here and
eliminate the need to serialize SimpleStreamReader and the wrapper.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -491,9 +492,15 @@ class MicroBatchExecution(
case (source: Source, end: Offset) =>
val start =
execCtx.startOffsets.get(source).map(_.asInstanceOf[Offset])
source.getBatch(start, end)
- case nonV1Tuple =>
- // The V2 API does not have the same edge case requiring
getBatch to be called
- // here, so we do nothing here.
+ case (source: PythonMicroBatchStream, end: Offset) =>
+ // PythonMicrobatchStream need to initialize the start
offset of prefetching
+ // by calling planInputPartitions of the last completed
batch during restart.
+ // We don't need to do that if there is incomplete batch in
the offset log
+ // because planInputPartitions during batch replay
initializes the start offset.
+ val start =
execCtx.startOffsets.get(source).map(_.asInstanceOf[Offset])
+
source.planInputPartitions(source.deserializeOffset(start.get.json()),
Review Comment:
Is it safe to assume that `start` is always Some(v) rather than None?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -35,6 +38,14 @@ class PythonMicroBatchStream(
ds.source.createPythonFunction(
ds.getOrCreateDataSourceInPython(shortName, options,
Some(outputSchema)).dataSource)
+ private val streamId = nextStreamId
+ private var nextBlockId = 0L
+
+ // planInputPartitions() maybe be called multiple times for the current
microbatch.
+ // Cache the result of planInputPartitions() because it may involve sending
data
+ // from python to JVM.
+ private var cachedInputPartition: Option[(String, String,
PythonStreamingInputPartition)] = None
Review Comment:
While the cache may become less necessary with the above change, it's still safer to
leave it as is. There is no strong guarantee that planInputPartitions() is called
only once (otherwise the above change would be a bugfix for a contract violation,
not an optimization).
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -66,9 +97,18 @@ class PythonMicroBatchStream(
}
override def stop(): Unit = {
+ cachedInputPartition.foreach(_._3.dropCache())
runner.stop()
}
override def deserializeOffset(json: String): Offset =
PythonStreamingSourceOffset(json)
}
+object PythonMicroBatchStream {
+ var currentId = 0
Review Comment:
I don't think this is thread-safe unless this variable is only accessed
through nextStreamId. If you don't intend to expose it publicly, please
explicitly restrict it (make it private).
##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceStreamReader,
+ InputPartition,
+ SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) ->
"DataSourceStreamReader":
+ """
+ Fallback to simpleStreamReader() method when streamReader() is not
implemented.
+ This should be invoked whenever a DataSourceStreamReader needs to be
created instead of
+ invoking datasource.streamReader() directly.
+ """
+ try:
+ return datasource.streamReader(schema=schema)
+ except PySparkNotImplementedError:
+ return
_SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+ def __init__(self, start: dict, end: dict):
+ self.start = start
+ self.end = end
+
+
+class PrefetchedCacheEntry:
+ def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+ self.start = start
+ self.end = end
+ self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+ """
+ A private class that wrap :class:`SimpleDataSourceStreamReader` in
prefetch and cache pattern,
+ so that :class:`SimpleDataSourceStreamReader` can integrate with streaming
engine like an
+ ordinary :class:`DataSourceStreamReader`.
+
+ current_offset tracks the latest progress of the record prefetching, it is
initialized to be
+ initialOffset() when query start for the first time or initialized to be
the end offset of
+ the last committed batch when query restarts.
+
+ When streaming engine calls latestOffset(), the wrapper calls read() that
starts from
+ current_offset, prefetches and cache the data, then updates the
current_offset to be
+ the end offset of the new data.
+
+ When streaming engine call planInputPartitions(start, end), the wrapper
get the prefetched data
+ from cache and send it to JVM along with the input partitions.
+
+ When query restart, batches in write ahead offset log that has not been
committed will be
+ replayed by reading data between start and end offset through
readBetweenOffsets(start, end).
+ """
+
+ def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+ self.simple_reader = simple_reader
+ self.initial_offset: Optional[dict] = None
+ self.current_offset: Optional[dict] = None
+ self.cache: List[PrefetchedCacheEntry] = []
+
+ def initialOffset(self) -> dict:
+ if self.initial_offset is None:
+ self.initial_offset = self.simple_reader.initialOffset()
+ return self.initial_offset
+
+ def latestOffset(self) -> dict:
+ # when query start for the first time, use initial offset as the start
offset.
+ if self.current_offset is None:
+ self.current_offset = self.initialOffset()
+ (iter, end) = self.simple_reader.read(self.current_offset)
+ self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+ self.current_offset = end
+ return end
+
+ def commit(self, end: dict) -> None:
+ if self.current_offset is None:
+ self.current_offset = end
+
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if end_idx > 0:
+ # Drop prefetched data for batch that has been committed.
+ self.cache = self.cache[end_idx:]
+ self.simple_reader.commit(end)
+
+ def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+ # when query restart from checkpoint, use the last committed offset as
the start offset.
+ # This depends on the streaming engine calling planInputPartitions()
of the last batch
+ # in offset log when query restart.
+ if self.current_offset is None:
+ self.current_offset = end
+ if len(self.cache) > 0:
+ assert self.cache[-1].end == end
+ return [SimpleInputPartition(start, end)]
+
+ def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+ start_idx = -1
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ # There is no convenient way to compare 2 offsets.
+ # Serialize into json string before comparison.
+ if json.dumps(entry.start) == json.dumps(start):
+ start_idx = idx
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if start_idx == -1 or end_idx == -1:
+ return None # type: ignore[return-value]
+ # Chain all the data iterator between start offset and end offset
+ # need to copy here to avoid exhausting the original data iterator.
+ entries = [copy.copy(entry.iterator) for entry in self.cache[start_idx
: end_idx + 1]]
+ it = chain(*entries)
+ return it
+
+ def read(
+ self, input_partition: SimpleInputPartition # type: ignore[override]
+ ) -> Iterator[Tuple]:
+ return self.simple_reader.readBetweenOffsets(input_partition.start,
input_partition.end)
Review Comment:
It sounds like there is also a case where the read method of the wrapper
class has to be serialized and executed in the task; that is, the simple reader
also needs to be serialized and executed in the task. Do I understand
correctly?
If so, I'd say you need to document this in
SimpleDataSourceStreamReader: since SimpleDataSourceStreamReader isn't
driver-only, implementers still need to consider serialization.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -44,9 +55,29 @@ class PythonMicroBatchStream(
override def latestOffset(): Offset =
PythonStreamingSourceOffset(runner.latestOffset())
override def planInputPartitions(start: Offset, end: Offset):
Array[InputPartition] = {
- runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
- end.asInstanceOf[PythonStreamingSourceOffset].json)
- .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+ val start_offset_json =
start.asInstanceOf[PythonStreamingSourceOffset].json
Review Comment:
nit: shall we use camelCase since we are in the Scala codebase?
##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceStreamReader,
+ InputPartition,
+ SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) ->
"DataSourceStreamReader":
+ """
+ Fallback to simpleStreamReader() method when streamReader() is not
implemented.
+ This should be invoked whenever a DataSourceStreamReader needs to be
created instead of
+ invoking datasource.streamReader() directly.
+ """
+ try:
+ return datasource.streamReader(schema=schema)
+ except PySparkNotImplementedError:
+ return
_SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+ def __init__(self, start: dict, end: dict):
+ self.start = start
+ self.end = end
+
+
+class PrefetchedCacheEntry:
+ def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+ self.start = start
+ self.end = end
+ self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+ """
+ A private class that wrap :class:`SimpleDataSourceStreamReader` in
prefetch and cache pattern,
+ so that :class:`SimpleDataSourceStreamReader` can integrate with streaming
engine like an
+ ordinary :class:`DataSourceStreamReader`.
+
+ current_offset tracks the latest progress of the record prefetching, it is
initialized to be
+ initialOffset() when query start for the first time or initialized to be
the end offset of
+ the last committed batch when query restarts.
+
+ When streaming engine calls latestOffset(), the wrapper calls read() that
starts from
+ current_offset, prefetches and cache the data, then updates the
current_offset to be
+ the end offset of the new data.
+
+ When streaming engine call planInputPartitions(start, end), the wrapper
get the prefetched data
+ from cache and send it to JVM along with the input partitions.
+
+ When query restart, batches in write ahead offset log that has not been
committed will be
+ replayed by reading data between start and end offset through
readBetweenOffsets(start, end).
+ """
+
+ def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+ self.simple_reader = simple_reader
+ self.initial_offset: Optional[dict] = None
+ self.current_offset: Optional[dict] = None
+ self.cache: List[PrefetchedCacheEntry] = []
+
+ def initialOffset(self) -> dict:
+ if self.initial_offset is None:
+ self.initial_offset = self.simple_reader.initialOffset()
+ return self.initial_offset
+
+ def latestOffset(self) -> dict:
+ # when query start for the first time, use initial offset as the start
offset.
+ if self.current_offset is None:
+ self.current_offset = self.initialOffset()
+ (iter, end) = self.simple_reader.read(self.current_offset)
+ self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+ self.current_offset = end
+ return end
+
+ def commit(self, end: dict) -> None:
+ if self.current_offset is None:
+ self.current_offset = end
+
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if end_idx > 0:
+ # Drop prefetched data for batch that has been committed.
+ self.cache = self.cache[end_idx:]
+ self.simple_reader.commit(end)
+
+ def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+ # when query restart from checkpoint, use the last committed offset as
the start offset.
+ # This depends on the streaming engine calling planInputPartitions()
of the last batch
+ # in offset log when query restart.
+ if self.current_offset is None:
+ self.current_offset = end
+ if len(self.cache) > 0:
+ assert self.cache[-1].end == end
+ return [SimpleInputPartition(start, end)]
+
+ def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+ start_idx = -1
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ # There is no convenient way to compare 2 offsets.
+ # Serialize into json string before comparison.
+ if json.dumps(entry.start) == json.dumps(start):
Review Comment:
Does this mean we have a case where the offset range spans multiple cache
entries? Or is it just defensive programming?
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.PythonStreamBlockId
+
+
+case class PythonStreamingInputPartition(
+ index: Int,
+ pickedPartition: Array[Byte],
+ blockId: Option[PythonStreamBlockId]) extends InputPartition {
+ def dropCache(): Unit = {
+ blockId.foreach(SparkEnv.get.blockManager.master.removeBlock(_))
+ }
+}
+
+class PythonStreamingPartitionReaderFactory(
+ source: UserDefinedPythonDataSource,
+ pickledReadFunc: Array[Byte],
+ outputSchema: StructType,
+ jobArtifactUUID: Option[String])
+ extends PartitionReaderFactory with Logging {
+
+ override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
+ val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+ // Maybe read from cached block prefetched by SimpleStreamReader
+ lazy val cachedBlock = if (part.blockId.isDefined) {
+ val block = SparkEnv.get.blockManager.get[InternalRow](part.blockId.get)
+ .map(_.data.asInstanceOf[Iterator[InternalRow]])
+ if (block.isEmpty) {
+ logWarning(s"Prefetched block ${part.blockId} for Python data source
not found.")
+ }
+ block
+ } else None
+
+ new PartitionReader[InternalRow] {
+
+ private[this] val metrics: Map[String, SQLMetric] =
PythonCustomMetric.pythonMetrics
+
+ private val outputIter = if (cachedBlock.isEmpty) {
+ // Evaluate the python read UDF if the partition is not cached as
block.
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledReadFunc,
+ "read_from_data_source",
+ UserDefinedPythonDataSource.readInputSchema,
+ outputSchema,
+ metrics,
+ jobArtifactUUID)
+
+ evaluatorFactory.createEvaluator().eval(
+ part.index, Iterator.single(InternalRow(part.pickedPartition)))
+ } else cachedBlock.get
+
+ override def next(): Boolean = outputIter.hasNext
+
+ override def get(): InternalRow = outputIter.next()
+
+ override def close(): Unit = {}
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ source.createPythonTaskMetrics(metrics.map { case (k, v) => k ->
v.value})
Review Comment:
nit: add a space between `v.value` and `}`.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -35,6 +38,14 @@ class PythonMicroBatchStream(
ds.source.createPythonFunction(
ds.getOrCreateDataSourceInPython(shortName, options,
Some(outputSchema)).dataSource)
+ private val streamId = nextStreamId
+ private var nextBlockId = 0L
+
+ // planInputPartitions() maybe be called multiple times for the current
microbatch.
+ // Cache the result of planInputPartitions() because it may involve sending
data
+ // from python to JVM.
+ private var cachedInputPartition: Option[(String, String,
PythonStreamingInputPartition)] = None
Review Comment:
@HyukjinKwon @allisonwang-db
Is there a case where a Python data source would support columnar reads? (I
assume not.)
If not, should we override columnarSupportMode in
PythonTable to explicitly declare it unsupported? This is a known issue related to
multiple calls of planInputPartitions().
https://github.com/apache/spark/pull/42823
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -164,7 +178,20 @@ class PythonStreamingSourceRunner(
val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
pickledPartitions.append(pickledPartition)
}
- pickledPartitions.toArray
+ val prefetchedRecordsStatus = dataIn.readInt()
+ val iter: Option[Iterator[InternalRow]] = prefetchedRecordsStatus match {
+ case NON_EMPTY_PYARROW_RECORD_BATCHES => Some(readArrowRecordBatches())
+ case PREFETCHED_RECORDS_NOT_FOUND => None
Review Comment:
Never mind. Got it.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.PythonStreamBlockId
+
+
+case class PythonStreamingInputPartition(
+ index: Int,
+ pickedPartition: Array[Byte],
+ blockId: Option[PythonStreamBlockId]) extends InputPartition {
+ def dropCache(): Unit = {
+ blockId.foreach(SparkEnv.get.blockManager.master.removeBlock(_))
+ }
+}
+
+class PythonStreamingPartitionReaderFactory(
+ source: UserDefinedPythonDataSource,
+ pickledReadFunc: Array[Byte],
+ outputSchema: StructType,
+ jobArtifactUUID: Option[String])
+ extends PartitionReaderFactory with Logging {
+
+ override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
+ val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+ // Maybe read from cached block prefetched by SimpleStreamReader
+ lazy val cachedBlock = if (part.blockId.isDefined) {
+ val block = SparkEnv.get.blockManager.get[InternalRow](part.blockId.get)
+ .map(_.data.asInstanceOf[Iterator[InternalRow]])
+ if (block.isEmpty) {
+ logWarning(s"Prefetched block ${part.blockId} for Python data source
not found.")
+ }
+ block
+ } else None
+
+ new PartitionReader[InternalRow] {
+
+ private[this] val metrics: Map[String, SQLMetric] =
PythonCustomMetric.pythonMetrics
+
+ private val outputIter = if (cachedBlock.isEmpty) {
+ // Evaluate the python read UDF if the partition is not cached as
block.
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledReadFunc,
Review Comment:
Got it - we need to serialize SimpleDataSourceStreamReader to cover the
fallback case where no prefetched block is available.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -199,4 +223,30 @@ class PythonStreamingSourceRunner(
logError("Exception when trying to kill worker", e)
}
}
+
+ private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ s"stream reader for $pythonExec", 0, Long.MaxValue)
+
+ def readArrowRecordBatches(): Iterator[InternalRow] = {
+ assert(dataIn.readInt() == SpecialLengths.START_ARROW_STREAM)
+ val reader = new ArrowStreamReader(dataIn, allocator)
+ val root = reader.getVectorSchemaRoot()
+ // When input is empty schema can't be read.
+ val schema = ArrowUtils.fromArrowSchema(root.getSchema())
+ assert(schema == outputSchema)
+
+ val vectors = root.getFieldVectors().asScala.map { vector =>
+ new ArrowColumnVector(vector)
+ }.toArray[ColumnVector]
+ val rows = ArrayBuffer[InternalRow]()
Review Comment:
I imagine there may still be a way to avoid materializing all rows at once
(e.g. per Arrow batch), but I'm not too concerned about it since we know the simple
data source isn't intended to handle a huge amount of data.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala:
##########
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.metric.CustomTaskMetric
+import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader,
PartitionReaderFactory}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.PythonStreamBlockId
+
+
+case class PythonStreamingInputPartition(
+ index: Int,
+ pickedPartition: Array[Byte],
+ blockId: Option[PythonStreamBlockId]) extends InputPartition {
+ def dropCache(): Unit = {
+ blockId.foreach(SparkEnv.get.blockManager.master.removeBlock(_))
+ }
+}
+
+class PythonStreamingPartitionReaderFactory(
+ source: UserDefinedPythonDataSource,
+ pickledReadFunc: Array[Byte],
+ outputSchema: StructType,
+ jobArtifactUUID: Option[String])
+ extends PartitionReaderFactory with Logging {
+
+ override def createReader(partition: InputPartition):
PartitionReader[InternalRow] = {
+ val part = partition.asInstanceOf[PythonStreamingInputPartition]
+
+ // Maybe read from cached block prefetched by SimpleStreamReader
+ lazy val cachedBlock = if (part.blockId.isDefined) {
+ val block = SparkEnv.get.blockManager.get[InternalRow](part.blockId.get)
+ .map(_.data.asInstanceOf[Iterator[InternalRow]])
+ if (block.isEmpty) {
+ logWarning(s"Prefetched block ${part.blockId} for Python data source
not found.")
+ }
+ block
+ } else None
+
+ new PartitionReader[InternalRow] {
+
+ private[this] val metrics: Map[String, SQLMetric] =
PythonCustomMetric.pythonMetrics
+
+ private val outputIter = if (cachedBlock.isEmpty) {
+ // Evaluate the python read UDF if the partition is not cached as
block.
+ val evaluatorFactory = source.createMapInBatchEvaluatorFactory(
+ pickledReadFunc,
Review Comment:
Do we have a way to trigger this path artificially? Never mind if it's not
feasible - it looks non-trivial, but it would be awesome if we could test this
path as well.
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -44,9 +55,29 @@ class PythonMicroBatchStream(
override def latestOffset(): Offset =
PythonStreamingSourceOffset(runner.latestOffset())
override def planInputPartitions(start: Offset, end: Offset):
Array[InputPartition] = {
- runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
- end.asInstanceOf[PythonStreamingSourceOffset].json)
- .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+ val start_offset_json =
start.asInstanceOf[PythonStreamingSourceOffset].json
+ val end_offset_json = end.asInstanceOf[PythonStreamingSourceOffset].json
Review Comment:
same here
##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceStreamReader,
+ InputPartition,
+ SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) ->
"DataSourceStreamReader":
+ """
+ Fallback to simpleStreamReader() method when streamReader() is not
implemented.
+ This should be invoked whenever a DataSourceStreamReader needs to be
created instead of
+ invoking datasource.streamReader() directly.
+ """
+ try:
+ return datasource.streamReader(schema=schema)
+ except PySparkNotImplementedError:
+ return
_SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+ def __init__(self, start: dict, end: dict):
+ self.start = start
+ self.end = end
+
+
+class PrefetchedCacheEntry:
+ def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+ self.start = start
+ self.end = end
+ self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+ """
+ A private class that wrap :class:`SimpleDataSourceStreamReader` in
prefetch and cache pattern,
+ so that :class:`SimpleDataSourceStreamReader` can integrate with streaming
engine like an
+ ordinary :class:`DataSourceStreamReader`.
+
+ current_offset tracks the latest progress of the record prefetching, it is
initialized to be
+ initialOffset() when query start for the first time or initialized to be
the end offset of
+ the last committed batch when query restarts.
+
+ When streaming engine calls latestOffset(), the wrapper calls read() that
starts from
+ current_offset, prefetches and cache the data, then updates the
current_offset to be
+ the end offset of the new data.
+
+ When streaming engine call planInputPartitions(start, end), the wrapper
get the prefetched data
+ from cache and send it to JVM along with the input partitions.
+
+ When query restart, batches in write ahead offset log that has not been
committed will be
+ replayed by reading data between start and end offset through
readBetweenOffsets(start, end).
+ """
+
+ def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+ self.simple_reader = simple_reader
+ self.initial_offset: Optional[dict] = None
+ self.current_offset: Optional[dict] = None
+ self.cache: List[PrefetchedCacheEntry] = []
+
+ def initialOffset(self) -> dict:
+ if self.initial_offset is None:
+ self.initial_offset = self.simple_reader.initialOffset()
+ return self.initial_offset
+
+ def latestOffset(self) -> dict:
+ # when query start for the first time, use initial offset as the start
offset.
+ if self.current_offset is None:
+ self.current_offset = self.initialOffset()
+ (iter, end) = self.simple_reader.read(self.current_offset)
+ self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+ self.current_offset = end
+ return end
+
+ def commit(self, end: dict) -> None:
+ if self.current_offset is None:
+ self.current_offset = end
+
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if end_idx > 0:
Review Comment:
Correct me if I'm missing something. According to the interface contract,
the offset "end" won't be requested again. Doesn't that mean this should be `end_idx >
-1` and `self.cache = self.cache[end_idx+1:]`? Is there any reason we keep the cached
entry whose end offset matches `end`?
##########
python/pyspark/sql/datasource_internal.py:
##########
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import json
+import copy
+from itertools import chain
+from typing import Iterator, List, Optional, Sequence, Tuple
+
+from pyspark.sql.datasource import (
+ DataSource,
+ DataSourceStreamReader,
+ InputPartition,
+ SimpleDataSourceStreamReader,
+)
+from pyspark.sql.types import StructType
+from pyspark.errors import PySparkNotImplementedError
+
+
+def _streamReader(datasource: DataSource, schema: StructType) ->
"DataSourceStreamReader":
+ """
+ Fallback to simpleStreamReader() method when streamReader() is not
implemented.
+ This should be invoked whenever a DataSourceStreamReader needs to be
created instead of
+ invoking datasource.streamReader() directly.
+ """
+ try:
+ return datasource.streamReader(schema=schema)
+ except PySparkNotImplementedError:
+ return
_SimpleStreamReaderWrapper(datasource.simpleStreamReader(schema=schema))
+
+
+class SimpleInputPartition(InputPartition):
+ def __init__(self, start: dict, end: dict):
+ self.start = start
+ self.end = end
+
+
+class PrefetchedCacheEntry:
+ def __init__(self, start: dict, end: dict, iterator: Iterator[Tuple]):
+ self.start = start
+ self.end = end
+ self.iterator = iterator
+
+
+class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+ """
+ A private class that wrap :class:`SimpleDataSourceStreamReader` in
prefetch and cache pattern,
+ so that :class:`SimpleDataSourceStreamReader` can integrate with streaming
engine like an
+ ordinary :class:`DataSourceStreamReader`.
+
+ current_offset tracks the latest progress of the record prefetching, it is
initialized to be
+ initialOffset() when query start for the first time or initialized to be
the end offset of
+ the last committed batch when query restarts.
+
+ When streaming engine calls latestOffset(), the wrapper calls read() that
starts from
+ current_offset, prefetches and cache the data, then updates the
current_offset to be
+ the end offset of the new data.
+
+ When streaming engine call planInputPartitions(start, end), the wrapper
get the prefetched data
+ from cache and send it to JVM along with the input partitions.
+
+ When query restart, batches in write ahead offset log that has not been
committed will be
+ replayed by reading data between start and end offset through
readBetweenOffsets(start, end).
+ """
+
+ def __init__(self, simple_reader: SimpleDataSourceStreamReader):
+ self.simple_reader = simple_reader
+ self.initial_offset: Optional[dict] = None
+ self.current_offset: Optional[dict] = None
+ self.cache: List[PrefetchedCacheEntry] = []
+
+ def initialOffset(self) -> dict:
+ if self.initial_offset is None:
+ self.initial_offset = self.simple_reader.initialOffset()
+ return self.initial_offset
+
+ def latestOffset(self) -> dict:
+ # when query start for the first time, use initial offset as the start
offset.
+ if self.current_offset is None:
+ self.current_offset = self.initialOffset()
+ (iter, end) = self.simple_reader.read(self.current_offset)
+ self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+ self.current_offset = end
+ return end
+
+ def commit(self, end: dict) -> None:
+ if self.current_offset is None:
+ self.current_offset = end
+
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if end_idx > 0:
+ # Drop prefetched data for batch that has been committed.
+ self.cache = self.cache[end_idx:]
+ self.simple_reader.commit(end)
+
+ def partitions(self, start: dict, end: dict) -> Sequence["InputPartition"]:
+ # when query restart from checkpoint, use the last committed offset as
the start offset.
+ # This depends on the streaming engine calling planInputPartitions()
of the last batch
+ # in offset log when query restart.
+ if self.current_offset is None:
+ self.current_offset = end
+ if len(self.cache) > 0:
+ assert self.cache[-1].end == end
+ return [SimpleInputPartition(start, end)]
+
+ def getCache(self, start: dict, end: dict) -> Iterator[Tuple]:
+ start_idx = -1
+ end_idx = -1
+ for idx, entry in enumerate(self.cache):
+ # There is no convenient way to compare 2 offsets.
+ # Serialize into json string before comparison.
+ if json.dumps(entry.start) == json.dumps(start):
+ start_idx = idx
+ if json.dumps(entry.end) == json.dumps(end):
+ end_idx = idx
+ break
+ if start_idx == -1 or end_idx == -1:
+ return None # type: ignore[return-value]
+ # Chain all the data iterator between start offset and end offset
+ # need to copy here to avoid exhausting the original data iterator.
+ entries = [copy.copy(entry.iterator) for entry in self.cache[start_idx
: end_idx + 1]]
+ it = chain(*entries)
+ return it
+
+ def read(
+ self, input_partition: SimpleInputPartition # type: ignore[override]
+ ) -> Iterator[Tuple]:
+ return self.simple_reader.readBetweenOffsets(input_partition.start,
input_partition.end)
Review Comment:
That said, if we change either `getCache` or the caller of `getCache` to
call `readBetweenOffsets` and go through the same path (sending the data via
Arrow batches), this method would no longer be called and we wouldn't need to
serialize the SimpleDataSourceStreamReader instance.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]