This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new e3e7135af4df [SPARK-47107][SS][PYTHON] Implement partition reader for python streaming data source
e3e7135af4df is described below
commit e3e7135af4df3427f4c61cccfe189f702844e1f5
Author: Chaoqin Li <[email protected]>
AuthorDate: Thu Mar 28 06:33:49 2024 +0900
    [SPARK-47107][SS][PYTHON] Implement partition reader for python streaming data source
### What changes were proposed in this pull request?
Piggyback on PythonPartitionReaderFactory to implement reading a data partition for the Python streaming data source. Add a test case to verify that a Python streaming data source can read and process data end to end.
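For reference, a minimal sketch of the Python-side contract this wires up, closely modeled on the SimpleDataStreamReader scripts in the new PythonStreamingDataSourceSuite tests (class names and the offset layout here are illustrative, not part of the patch):

    from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition

    class RangePartition(InputPartition):
        def __init__(self, start, end):
            self.start = start
            self.end = end

    class SimpleStreamReader(DataSourceStreamReader):
        current = 0

        def initialOffset(self):
            return {"offset": 0}

        def latestOffset(self):
            # Advance the offset by 2 for every microbatch.
            self.current += 2
            return {"offset": self.current}

        def partitions(self, start, end):
            # Called per microbatch by PythonMicroBatchStream.planInputPartitions;
            # unlike the batch path, no partitions are planned at query start.
            return [RangePartition(start["offset"], end["offset"])]

        def read(self, partition):
            # Executed by the partition reader this patch implements.
            for i in range(partition.start, partition.end):
                yield (i,)

        def commit(self, end):
            pass

    class SimpleStreamDataSource(DataSource):
        def schema(self):
            return "id INT"

        def streamReader(self, schema):
            return SimpleStreamReader()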
### Why are the changes needed?
This is part of the effort to support developing streaming data sources with the Python interface.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Add an integration test to verify that data is read and metrics are emitted correctly.
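The new Scala tests drive the source through foreachBatch and check recentProgress.numInputRows; the equivalent end-to-end flow from PySpark, assuming the sketch above is registered, would look roughly like this (the format name, batch check, and timeout are illustrative):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(SimpleStreamDataSource)

    def check_batch(batch_df, batch_id):
        # latestOffset advances by 2 per microbatch, so each batch carries 2 rows.
        assert batch_df.count() == 2

    query = (
        spark.readStream.format("SimpleStreamDataSource").load()
        .writeStream.foreachBatch(check_batch)
        .start()
    )
    query.awaitTermination(30)
    query.stop()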
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #45485 from chaoqin-li1123/python_stream_read.
Authored-by: Chaoqin Li <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
.../streaming/python_streaming_source_runner.py | 2 +-
python/pyspark/sql/worker/plan_data_source_read.py | 75 +++++----
.../v2/python/PythonMicroBatchStream.scala | 16 +-
.../datasources/v2/python/PythonScan.scala | 3 +-
.../datasources/v2/python/PythonTable.scala | 4 +-
.../v2/python/UserDefinedPythonDataSource.scala | 16 +-
.../spark/sql/streaming/DataStreamReader.scala | 5 +
.../python/PythonStreamingDataSourceSuite.scala | 182 ++++++++++++++++++---
8 files changed, 238 insertions(+), 65 deletions(-)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 8dbac431a8ba..512191866a16 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -141,7 +141,7 @@ def main(infile: IO, outfile: IO) -> None:
        error_msg = "data source {} throw exception: {}".format(data_source.name, e)
raise PySparkRuntimeError(
error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR",
- message_parameters={"error": error_msg},
+ message_parameters={"msg": error_msg},
)
finally:
reader.stop()
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 8f1fc1e59a61..3e5105996ed4 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -25,6 +25,7 @@ from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import (
+ read_bool,
read_int,
write_int,
SpecialLengths,
@@ -127,33 +128,14 @@ def main(infile: IO, outfile: IO) -> None:
f"'{max_arrow_batch_size}'"
)
- # Instantiate data source reader.
- reader = data_source.reader(schema=schema)
+ is_streaming = read_bool(infile)
- # Get the partitions if any.
- try:
- partitions = reader.partitions()
- if not isinstance(partitions, list):
- raise PySparkRuntimeError(
- error_class="DATA_SOURCE_TYPE_MISMATCH",
- message_parameters={
- "expected": "'partitions' to return a list",
- "actual": f"'{type(partitions).__name__}'",
- },
- )
- if not all(isinstance(p, InputPartition) for p in partitions):
- partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
- raise PySparkRuntimeError(
- error_class="DATA_SOURCE_TYPE_MISMATCH",
- message_parameters={
- "expected": "all elements in 'partitions' to be of
type 'InputPartition'",
- "actual": partition_types,
- },
- )
- if len(partitions) == 0:
- partitions = [None] # type: ignore
- except NotImplementedError:
- partitions = [None] # type: ignore
+ # Instantiate data source reader.
+ reader = (
+ data_source.streamReader(schema=schema)
+ if is_streaming
+ else data_source.reader(schema=schema)
+ )
# Wrap the data source read logic in an mapInArrow UDF.
import pyarrow as pa
@@ -195,7 +177,7 @@ def main(infile: IO, outfile: IO) -> None:
f"but found '{type(partition).__name__}'."
)
- output_iter = reader.read(partition) # type: ignore[arg-type]
+ output_iter = reader.read(partition) # type: ignore[attr-defined]
# Validate the output iterator.
if not isinstance(output_iter, Iterator):
@@ -264,11 +246,40 @@ def main(infile: IO, outfile: IO) -> None:
command = (data_source_read_func, return_type)
pickleSer._write_with_length(command, outfile)
- # Return the serialized partition values.
- write_int(len(partitions), outfile)
- for partition in partitions:
- pickleSer._write_with_length(partition, outfile)
-
+ if not is_streaming:
+ # The partitioning of python batch source read is determined before query execution.
+ try:
+ partitions = reader.partitions() # type: ignore[attr-defined]
+ if not isinstance(partitions, list):
+ raise PySparkRuntimeError(
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "'partitions' to return a list",
+ "actual": f"'{type(partitions).__name__}'",
+ },
+ )
+ if not all(isinstance(p, InputPartition) for p in partitions):
+ partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
+ raise PySparkRuntimeError(
+ error_class="DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "elements in 'partitions' to be of
type 'InputPartition'",
+ "actual": partition_types,
+ },
+ )
+ if len(partitions) == 0:
+ partitions = [None]
+ except NotImplementedError:
+ partitions = [None]
+
+ # Return the serialized partition values.
+ write_int(len(partitions), outfile)
+ for partition in partitions:
+ pickleSer._write_with_length(partition, outfile)
+ else:
+ # Send an empty list of partition for stream reader because partitions are planned
+ # in each microbatch during query execution.
+ write_int(0, outfile)
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 4e77f33c24f0..71e6c29bc299 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -25,8 +25,6 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class PythonStreamingSourceOffset(json: String) extends Offset
-case class PythonStreamingSourcePartition(partition: Array[Byte]) extends InputPartition
-
class PythonMicroBatchStream(
ds: PythonDataSourceV2,
shortName: String,
@@ -47,12 +45,20 @@ class PythonMicroBatchStream(
override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = {
    runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
- end.asInstanceOf[PythonStreamingSourceOffset].json).map(PythonStreamingSourcePartition(_))
+ end.asInstanceOf[PythonStreamingSourceOffset].json)
+ .zipWithIndex.map(p => PythonInputPartition(p._2, p._1))
+ }
+
+ private lazy val readInfo: PythonDataSourceReadInfo = {
+ ds.source.createReadInfoInPython(
+ ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
+ outputSchema,
+ isStreaming = true)
}
override def createReaderFactory(): PartitionReaderFactory = {
- // TODO(SPARK-47107): fill in the implementation.
- null
+ new PythonPartitionReaderFactory(
+ ds.source, readInfo.func, outputSchema, None)
}
override def commit(end: Offset): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index bcddf66fc161..8fefc8b144a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -53,7 +53,8 @@ class PythonBatch(
private lazy val infoInPython: PythonDataSourceReadInfo = {
ds.source.createReadInfoInPython(
ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
- outputSchema)
+ outputSchema,
+ isStreaming = false)
}
override def planInputPartitions(): Array[InputPartition] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
index 0476650a60bf..f633e601f424 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonTable.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability}
-import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE}
+import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE, MICRO_BATCH_READ, STREAMING_WRITE, TRUNCATE}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
import org.apache.spark.sql.types.StructType
@@ -32,7 +32,7 @@ class PythonTable(
override def name(): String = shortName
override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(
- BATCH_READ, BATCH_WRITE, STREAMING_WRITE, TRUNCATE)
+ BATCH_READ, BATCH_WRITE, MICRO_BATCH_READ, STREAMING_WRITE, TRUNCATE)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new PythonScanBuilder(ds, shortName, outputSchema, options)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 0586d1fd4bc1..241d8087fc3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -67,11 +67,13 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
*/
def createReadInfoInPython(
pythonResult: PythonDataSourceCreationResult,
- outputSchema: StructType): PythonDataSourceReadInfo = {
+ outputSchema: StructType,
+ isStreaming: Boolean): PythonDataSourceReadInfo = {
new UserDefinedPythonDataSourceReadRunner(
createPythonFunction(pythonResult.dataSource),
UserDefinedPythonDataSource.readInputSchema,
- outputSchema).runInPython()
+ outputSchema,
+ isStreaming).runInPython()
}
/**
@@ -312,7 +314,8 @@ case class PythonDataSourceReadInfo(
private class UserDefinedPythonDataSourceReadRunner(
func: PythonFunction,
inputSchema: StructType,
- outputSchema: StructType) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
+ outputSchema: StructType,
+ isStreaming: Boolean) extends PythonPlannerRunner[PythonDataSourceReadInfo](func) {
// See the logic in `pyspark.sql.worker.plan_data_source_read.py`.
override val workerModule = "pyspark.sql.worker.plan_data_source_read"
@@ -329,6 +332,8 @@ private class UserDefinedPythonDataSourceReadRunner(
// Send configurations
dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+
+ dataOut.writeBoolean(isStreaming)
}
override protected def receiveFromPython(dataIn: DataInputStream): PythonDataSourceReadInfo = {
@@ -346,6 +351,11 @@ private class UserDefinedPythonDataSourceReadRunner(
// Receive the list of partitions, if any.
val pickledPartitions = ArrayBuffer.empty[Array[Byte]]
val numPartitions = dataIn.readInt()
+ if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryCompilationErrors.pythonDataSourceError(
+ action = "plan", tpe = "read", msg = msg)
+ }
for (_ <- 0 until numPartitions) {
val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
pickledPartitions.append(pickledPartition)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 6e24e14fb1eb..24d769fc8fc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.json.JsonUtils.checkJsonSchema
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlUtils.checkXmlSchema
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.sources.StreamSourceProvider
@@ -178,6 +179,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++
optionsWithPath.originalMap
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
+ provider match {
+ case p: PythonDataSourceV2 => p.setShortName(source)
+ case _ =>
+ }
    val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
index 42eaa492be73..6f4bd1888fbb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.python
import java.io.File
+import java.util.concurrent.CountDownLatch
import scala.concurrent.duration._
@@ -24,7 +25,8 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
-import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.{MemoryStream, ProcessingTimeTrigger}
+import org.apache.spark.sql.streaming.StreamingQueryException
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -39,10 +41,12 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
|from pyspark.sql.datasource import DataSourceStreamReader, InputPartition
|
|class SimpleDataStreamReader(DataSourceStreamReader):
+ | current = 0
| def initialOffset(self):
| return {"offset": {"partition-1": 0}}
| def latestOffset(self):
- | return {"offset": {"partition-1": 2}}
+ | self.current += 2
+ | return {"offset": {"partition-1": self.current}}
| def partitions(self, start: dict, end: dict):
| start_index = start["offset"]["partition-1"]
| end_index = end["offset"]["partition-1"]
@@ -50,9 +54,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
| def commit(self, end: dict):
| 1 + 2
| def read(self, partition):
- | yield (0, partition.value)
- | yield (1, partition.value)
- | yield (2, partition.value)
+ | yield (partition.value,)
|""".stripMargin
protected def errorDataStreamReaderScript: String =
@@ -110,7 +112,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
private val errorDataSourceName = "ErrorDataSource"
- test("simple data stream source") {
+ test("Test PythonMicroBatchStream") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
@@ -130,40 +132,178 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
val stream = new PythonMicroBatchStream(
pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
- val initialOffset = stream.initialOffset()
- assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}")
- for (_ <- 1 to 50) {
- val offset = stream.latestOffset()
- assert(offset.json == "{\"offset\": {\"partition-1\": 2}}")
- assert(stream.planInputPartitions(initialOffset, offset).size == 2)
- stream.commit(offset)
+ var startOffset = stream.initialOffset()
+ assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+ for (i <- 1 to 50) {
+ val endOffset = stream.latestOffset()
+ assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
+ assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
+ stream.commit(endOffset)
+ startOffset = endOffset
}
stream.stop()
}
+ test("Read from simple data stream source") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource
+ |$simpleDataStreamReaderScript
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ | def streamReader(self, schema):
+ | return SimpleDataStreamReader()
+ |""".stripMargin
+
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ val df = spark.readStream.format(dataSourceName).load()
+
+ val stopSignal = new CountDownLatch(1)
+
+ val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+ // checkAnswer may materialize the dataframe more than once
+ // Cache here to make sure the numInputRows metrics is consistent.
+ df.cache()
+ checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+ if (batchId > 30) stopSignal.countDown()
+ }).trigger(ProcessingTimeTrigger(0)).start()
+ stopSignal.await()
+ assert(q.recentProgress.forall(_.numInputRows == 2))
+ q.stop()
+ q.awaitTermination()
+ }
+
+ test("Streaming data source read with custom partitions") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition
+ |class RangePartition(InputPartition):
+ | def __init__(self, start, end):
+ | self.start = start
+ | self.end = end
+ |
+ |class SimpleDataStreamReader(DataSourceStreamReader):
+ | current = 0
+ | def initialOffset(self):
+ | return {"offset": 0}
+ | def latestOffset(self):
+ | self.current += 2
+ | return {"offset": self.current}
+ | def partitions(self, start: dict, end: dict):
+ | return [RangePartition(start["offset"], end["offset"])]
+ | def commit(self, end: dict):
+ | 1 + 2
+ | def read(self, partition: RangePartition):
+ | start, end = partition.start, partition.end
+ | for i in range(start, end):
+ | yield (i, )
+ |
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ |
+ | def streamReader(self, schema):
+ | return SimpleDataStreamReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ val df = spark.readStream.format(dataSourceName).load()
+
+ val stopSignal = new CountDownLatch(1)
+
+ val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+ // checkAnswer may materialize the dataframe more than once
+ // Cache here to make sure the numInputRows metrics is consistent.
+ df.cache()
+ checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+ if (batchId > 30) stopSignal.countDown()
+ }).trigger(ProcessingTimeTrigger(0)).start()
+ stopSignal.await()
+ assert(q.recentProgress.forall(_.numInputRows == 2))
+ q.stop()
+ q.awaitTermination()
+ }
+
test("Error creating stream reader") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource
|class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
| def streamReader(self, schema):
| raise Exception("error creating stream reader")
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(
name = dataSourceName, pythonScript = dataSourceScript)
spark.dataSource.registerPython(dataSourceName, dataSource)
- val pythonDs = new PythonDataSourceV2
- pythonDs.setShortName("SimpleDataSource")
- val inputSchema = StructType.fromDDL("input BINARY")
- val err = intercept[AnalysisException] {
- new PythonMicroBatchStream(
- pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+
+ val err = intercept[StreamingQueryException] {
+ val q = spark.readStream.format(dataSourceName).load()
+ .writeStream.format("console").start()
+ q.awaitTermination()
}
- assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR")
+ assert(err.getErrorClass == "STREAM_FAILED")
assert(err.getMessage.contains("error creating stream reader"))
}
+ test("Streaming data source read error") {
+ assume(shouldTestPandasUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceStreamReader, InputPartition
+ |class RangePartition(InputPartition):
+ | def __init__(self, start, end):
+ | self.start = start
+ | self.end = end
+ |
+ |class SimpleDataStreamReader(DataSourceStreamReader):
+ | current = 0
+ | def initialOffset(self):
+ | return {"offset": "0"}
+ | def latestOffset(self):
+ | self.current += 2
+ | return {"offset": str(self.current)}
+ | def partitions(self, start: dict, end: dict):
+ | return [RangePartition(int(start["offset"]), int(end["offset"]))]
+ | def commit(self, end: dict):
+ | 1 + 2
+ | def read(self, partition: RangePartition):
+ | raise Exception("error reading data")
+ |
+ |
+ |class $dataSourceName(DataSource):
+ | def schema(self) -> str:
+ | return "id INT"
+ |
+ | def streamReader(self, schema):
+ | return SimpleDataStreamReader()
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+ val df = spark.readStream.format(dataSourceName).load()
+
+ val err = intercept[StreamingQueryException] {
+ val q = df.writeStream.foreachBatch((df: DataFrame, _: Long) => {
+ df.count()
+ ()
+ }).start()
+ q.awaitTermination()
+ }
+ assert(err.getMessage.contains("error reading data"))
+ }
+
+
test("Method not implemented in stream reader") {
assume(shouldTestPandasUDFs)
val dataSourceScript =
@@ -237,7 +377,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
spark.dataSource.registerPython(errorDataSourceName, dataSource)
val pythonDs = new PythonDataSourceV2
pythonDs.setShortName("ErrorDataSource")
- val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}")
+ val offset = PythonStreamingSourceOffset("{\"offset\": 2}")
def testMicroBatchStreamError(action: String, msg: String)
(func: PythonMicroBatchStream => Unit): Unit = {