HeartSaVioR commented on code in PR #45485:
URL: https://github.com/apache/spark/pull/45485#discussion_r1532121734


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -87,40 +90,178 @@ class PythonStreamingDataSourceSuite extends 
PythonDataSourceSuiteBase {
     val stream = new PythonMicroBatchStream(
       pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
 
-    val initialOffset = stream.initialOffset()
-    assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}")
-    for (_ <- 1 to 50) {
-      val offset = stream.latestOffset()
-      assert(offset.json == "{\"offset\": {\"partition-1\": 2}}")
-      assert(stream.planInputPartitions(initialOffset, offset).size == 2)
-      stream.commit(offset)
+    var startOffset = stream.initialOffset()
+    assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+    for (i <- 1 to 50) {
+      val endOffset = stream.latestOffset()
+      assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
+      assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
+      stream.commit(endOffset)
+      startOffset = endOffset
     }
     stream.stop()
   }
 
+  test("Read from simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val stopSignal = new CountDownLatch(1)
+
+    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+      // checkAnswer may materialize the dataframe more than once
+      // Cache here to make sure the numInputRows metrics is consistent.
+      df.cache()
+      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+      if (batchId > 30) stopSignal.countDown()
+    }).start()

Review Comment:
   nit: it would probably be good to explicitly set a processing-time trigger 
with 0 delay (e.g. `Trigger.ProcessingTime(0)`), so that the overall test time stays bounded.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -87,40 +90,178 @@ class PythonStreamingDataSourceSuite extends 
PythonDataSourceSuiteBase {
     val stream = new PythonMicroBatchStream(
       pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
 
-    val initialOffset = stream.initialOffset()
-    assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}")
-    for (_ <- 1 to 50) {
-      val offset = stream.latestOffset()
-      assert(offset.json == "{\"offset\": {\"partition-1\": 2}}")
-      assert(stream.planInputPartitions(initialOffset, offset).size == 2)
-      stream.commit(offset)
+    var startOffset = stream.initialOffset()
+    assert(startOffset.json == "{\"offset\": {\"partition-1\": 0}}")
+    for (i <- 1 to 50) {
+      val endOffset = stream.latestOffset()
+      assert(endOffset.json == s"""{"offset": {"partition-1": ${2 * i}}}""")
+      assert(stream.planInputPartitions(startOffset, endOffset).size == 2)
+      stream.commit(endOffset)
+      startOffset = endOffset
     }
     stream.stop()
   }
 
+  test("Read from simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val stopSignal = new CountDownLatch(1)
+
+    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+      // checkAnswer may materialize the dataframe more than once
+      // Cache here to make sure the numInputRows metrics is consistent.
+      df.cache()
+      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+      if (batchId > 30) stopSignal.countDown()
+    }).start()
+    stopSignal.await()
+    assert(q.recentProgress.forall(_.numInputRows == 2))
+    q.stop()
+    q.awaitTermination()
+  }
+
+  test("Streaming data source read with custom partitions") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, 
DataSourceStreamReader, InputPartition
+         |class RangePartition(InputPartition):
+         |    def __init__(self, start, end):
+         |        self.start = start
+         |        self.end = end
+         |
+         |class SimpleDataStreamReader(DataSourceStreamReader):
+         |    current = 0
+         |    def initialOffset(self):
+         |        return {"offset": 0}
+         |    def latestOffset(self):
+         |        self.current += 2
+         |        return {"offset": self.current}
+         |    def partitions(self, start: dict, end: dict):
+         |        return [RangePartition(start["offset"], end["offset"])]
+         |    def commit(self, end: dict):
+         |        1 + 2
+         |    def read(self, partition: RangePartition):
+         |        start, end = partition.start, partition.end
+         |        for i in range(start, end):
+         |            yield (i, )
+         |
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val df = spark.readStream.format(dataSourceName).load()
+
+    val stopSignal = new CountDownLatch(1)
+
+    val q = df.writeStream.foreachBatch((df: DataFrame, batchId: Long) => {
+      // checkAnswer may materialize the dataframe more than once
+      // Cache here to make sure the numInputRows metrics is consistent.
+      df.cache()
+      checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+      if (batchId > 30) stopSignal.countDown()
+    }).start()

Review Comment:
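   ditto: it would probably be good to explicitly set a processing-time trigger 
with 0 delay here as well, so that the overall test time stays bounded.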



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to