HeartSaVioR commented on code in PR #45023:
URL: https://github.com/apache/spark/pull/45023#discussion_r1489259159


##########
python/pyspark/sql/datasource.py:
##########
@@ -298,6 +320,104 @@ def read(self, partition: InputPartition) -> 
Iterator[Union[Tuple, Row]]:
         ...
 
 
+class DataSourceStreamReader(ABC):
+    """
+    A base class for streaming data source readers. Data source stream readers 
are responsible
+    for outputting data from a streaming data source.
+
+    .. versionadded: 4.0.0
+    """
+
+    def initialOffset(self) -> dict:
+        """
+        Return the initial offset of the streaming data source.
+        A new streaming query starts reading data from the initial offset.
+        If Spark is restarting an existing query, it will restart from the 
check-pointed offset
+        rather than the initial one.
+
+        Returns
+        -------
+        dict
+            A dict whose key and values are str type.

Review Comment:
   I just realized this is too limited - there are data sources whose offset 
structure is a complex model, e.g. an offset model consisting of a list 
containing case class instances as elements. It probably couldn't be bound to 
a flattened dictionary.
   
   Shall we revisit the design for the structure of the offset? Probably the 
closest approach to the proposed change is to simply allow a dict, while 
noting that we serialize the dict to JSON underneath via the following 
conversion table (hence the dict has to be compatible with it). 
   https://docs.python.org/3/library/json.html#py-to-json-table
   
   Or just introduce an interface for Offset which is expected to handle JSON 
serde by itself. We could probably document an example of how to implement 
JSON serde for a complex offset type.



##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import functools
+import json
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
+from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+initial_offset_func_id = 884
+latest_offset_func_id = 885
+partitions_func_id = 886
+commit_func_id = 887
+
+
+def initial_offset_func(reader, outfile):
+    offset = reader.initialOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def latest_offset_func(reader, outfile):
+    offset = reader.latestOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def partitions_func(reader, infile, outfile):
+    start_offset = json.loads(utf8_deserializer.loads(infile))
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    partitions = reader.partitions(start_offset, end_offset)
+    # Return the serialized partition values.
+    write_int(len(partitions), outfile)
+    for partition in partitions:
+        pickleSer._write_with_length(partition, outfile)
+
+
+def commit_func(reader, infile, outfile):
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    reader.commit(end_offset)
+    write_int(0, outfile)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)
+        schema = _parse_datatype_json_string(schema_json)
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an output schema of type 'StructType'",
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+
+        # Receive the configuration values.
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
+        # Instantiate data source reader.
+        try:
+            reader = data_source.streamReader(schema=schema)
+            # Initialization succeed.
+            write_int(0, outfile)
+            outfile.flush()
+
+            # handle method call from socket
+            while True:
+                func_id = read_int(infile)
+                if func_id == initial_offset_func_id:
+                    initial_offset_func(reader, outfile)
+                elif func_id == latest_offset_func_id:
+                    latest_offset_func(reader, outfile)
+                elif func_id == partitions_func_id:
+                    partitions_func(reader, infile, outfile)
+                elif func_id == commit_func_id:
+                    commit_func(reader, infile, outfile)
+                outfile.flush()
+
+        except NotImplementedError:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED",
+                message_parameters={"type": "reader", "method": "reader"},
+            )
+        except Exception as e:

Review Comment:
   Isn't this too broad compared to what we record it as? If we want to 
classify it as a "data source creation error", I think we should narrow the 
scope to the call of streamReader, not beyond.
   
   Please look at how the batch side does this.



##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import functools
+import json
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
+from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+initial_offset_func_id = 884
+latest_offset_func_id = 885
+partitions_func_id = 886
+commit_func_id = 887
+
+
+def initial_offset_func(reader, outfile):
+    offset = reader.initialOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def latest_offset_func(reader, outfile):
+    offset = reader.latestOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def partitions_func(reader, infile, outfile):
+    start_offset = json.loads(utf8_deserializer.loads(infile))
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    partitions = reader.partitions(start_offset, end_offset)
+    # Return the serialized partition values.
+    write_int(len(partitions), outfile)
+    for partition in partitions:
+        pickleSer._write_with_length(partition, outfile)
+
+
+def commit_func(reader, infile, outfile):
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    reader.commit(end_offset)
+    write_int(0, outfile)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)
+        schema = _parse_datatype_json_string(schema_json)
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an output schema of type 'StructType'",
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+
+        # Receive the configuration values.

Review Comment:
   This does not seem to be used - will it be referenced in a later part?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val prevConf = conf.get(PYTHON_USE_DAEMON)

Review Comment:
   @HyukjinKwon Do you anticipate any issue with temporarily changing the conf 
value and reverting it back?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import 
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, 
shouldTestPandasUDFs}
+import 
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, 
PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"0": "2"}
+      |    def latestOffset(self):
+      |        return {"0": "2"}
+      |    def partitions(self, start: dict, end: dict):
+      |        return [InputPartition(i) for i in range(int(start["0"]))]
+      |    def commit(self, end: dict):
+      |        1 + 2
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  protected def errorDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class ErrorDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        raise Exception("error reading initial offset")
+      |    def latestOffset(self):
+      |        raise Exception("error reading latest offset")
+      |    def partitions(self, start: dict, end: dict):
+      |        raise Exception("error planning partitions")
+      |    def commit(self, end: dict):
+      |        raise Exception("error committing offset")
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  private val errorDataSourceName = "ErrorDataSource"
+
+  test("simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val stream = new PythonMicroBatchStream(
+      pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+
+    val initialOffset = stream.initialOffset()
+    assert(initialOffset.json == "{\"0\": \"2\"}")
+    for (_ <- 1 to 50) {
+      val offset = stream.latestOffset()
+      assert(offset.json == "{\"0\": \"2\"}")
+      assert(stream.planInputPartitions(offset, offset).size == 2)
+      stream.commit(offset)
+    }
+    stream.stop()
+  }
+
+  test("Error creating stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        raise Exception("error creating stream reader")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val inputSchema = StructType.fromDDL("input BINARY")
+    val err = intercept[AnalysisException] {
+      new PythonMicroBatchStream(
+        pythonDs, dataSourceName, inputSchema, 
CaseInsensitiveStringMap.empty())
+    }
+    assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR")

Review Comment:
   If the error is bound to the error class framework, you need to verify it 
with checkError.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import 
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, 
shouldTestPandasUDFs}
+import 
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, 
PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"0": "2"}
+      |    def latestOffset(self):
+      |        return {"0": "2"}
+      |    def partitions(self, start: dict, end: dict):
+      |        return [InputPartition(i) for i in range(int(start["0"]))]

Review Comment:
   What is `start["0"]`?



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import 
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, 
shouldTestPandasUDFs}
+import 
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, 
PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"0": "2"}
+      |    def latestOffset(self):
+      |        return {"0": "2"}
+      |    def partitions(self, start: dict, end: dict):
+      |        return [InputPartition(i) for i in range(int(start["0"]))]
+      |    def commit(self, end: dict):
+      |        1 + 2
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  protected def errorDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class ErrorDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        raise Exception("error reading initial offset")
+      |    def latestOffset(self):
+      |        raise Exception("error reading latest offset")
+      |    def partitions(self, start: dict, end: dict):
+      |        raise Exception("error planning partitions")
+      |    def commit(self, end: dict):
+      |        raise Exception("error committing offset")
+      |    def read(self, partition):
+      |        yield (0, partition.value)
+      |        yield (1, partition.value)
+      |        yield (2, partition.value)
+      |""".stripMargin
+
+  private val errorDataSourceName = "ErrorDataSource"
+
+  test("simple data stream source") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$simpleDataStreamReaderScript
+         |
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val stream = new PythonMicroBatchStream(
+      pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty())
+
+    val initialOffset = stream.initialOffset()
+    assert(initialOffset.json == "{\"0\": \"2\"}")
+    for (_ <- 1 to 50) {
+      val offset = stream.latestOffset()
+      assert(offset.json == "{\"0\": \"2\"}")
+      assert(stream.planInputPartitions(offset, offset).size == 2)
+      stream.commit(offset)
+    }
+    stream.stop()
+  }
+
+  test("Error creating stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |class $dataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        raise Exception("error creating stream reader")
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(
+      name = dataSourceName, pythonScript = dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("SimpleDataSource")
+    val inputSchema = StructType.fromDDL("input BINARY")
+    val err = intercept[AnalysisException] {
+      new PythonMicroBatchStream(
+        pythonDs, dataSourceName, inputSchema, 
CaseInsensitiveStringMap.empty())
+    }
+    assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR")
+    assert(err.getMessage.contains("error creating stream reader"))
+  }
+
+  test("Error in stream reader") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |$errorDataStreamReaderScript
+         |
+         |class $errorDataSourceName(DataSource):
+         |    def streamReader(self, schema):
+         |        return ErrorDataStreamReader()
+         |""".stripMargin
+    val inputSchema = StructType.fromDDL("input BINARY")
+
+    val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, 
dataSourceScript)
+    spark.dataSource.registerPython(errorDataSourceName, dataSource)
+    val pythonDs = new PythonDataSourceV2
+    pythonDs.setShortName("ErrorDataSource")
+    val offset = PythonStreamingSourceOffset("{\"0\": \"2\"}")
+
+    def testMicroBatchStreamError(msg: String)
+                                 (func: PythonMicroBatchStream => Unit): Unit 
= {
+      val stream = new PythonMicroBatchStream(
+        pythonDs, errorDataSourceName, inputSchema, 
CaseInsensitiveStringMap.empty())
+      val err = intercept[SparkException] {
+        func(stream)
+      }
+      assert(err.getMessage.contains(msg))

Review Comment:
   If the error is bound to the error class framework, you need to verify it 
with checkError.



##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import functools
+import json
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
+from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+initial_offset_func_id = 884
+latest_offset_func_id = 885
+partitions_func_id = 886
+commit_func_id = 887
+
+
+def initial_offset_func(reader, outfile):
+    offset = reader.initialOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def latest_offset_func(reader, outfile):
+    offset = reader.latestOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def partitions_func(reader, infile, outfile):
+    start_offset = json.loads(utf8_deserializer.loads(infile))
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    partitions = reader.partitions(start_offset, end_offset)
+    # Return the serialized partition values.
+    write_int(len(partitions), outfile)
+    for partition in partitions:
+        pickleSer._write_with_length(partition, outfile)
+
+
+def commit_func(reader, infile, outfile):
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    reader.commit(end_offset)
+    write_int(0, outfile)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)

Review Comment:
   What is the expected behavior if the user does not specify a schema and 
relies on the schema of the data source (which is the case for the majority of 
data sources)?



##########
python/pyspark/sql/datasource.py:
##########
@@ -159,6 +160,27 @@ def writer(self, schema: StructType, overwrite: bool) -> 
"DataSourceWriter":
             message_parameters={"feature": "writer"},
         )
 
+    def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
+        """
+        Returns a ``DataSourceStreamReader`` instance for reading streaming 
data.
+
+        The implementation is required for readable streaming data sources.
+
+        Parameters
+        ----------
+        schema : StructType
+            The schema of the data to be written.

Review Comment:
   This seems to have been copied incorrectly from the batch writer. It should 
be copied from the batch reader instead.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class PythonStreamingSourceOffset(json: String) extends Offset
+
+case class PythonStreamingSourcePartition(partition: Array[Byte]) extends 
InputPartition
+
+class PythonMicroBatchStream(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap
+  ) extends MicroBatchStream with Logging {
+  private def createDataSourceFunc =
+    ds.source.createPythonFunction(
+      ds.getOrCreateDataSourceInPython(shortName, options, 
Some(outputSchema)).dataSource)
+
+  val runner: PythonStreamingSourceRunner =

Review Comment:
   Does this need to be public?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class PythonStreamingSourceOffset(json: String) extends Offset
+
+case class PythonStreamingSourcePartition(partition: Array[Byte]) extends 
InputPartition
+
+class PythonMicroBatchStream(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap
+  ) extends MicroBatchStream with Logging {
+  private def createDataSourceFunc =
+    ds.source.createPythonFunction(
+      ds.getOrCreateDataSourceInPython(shortName, options, 
Some(outputSchema)).dataSource)
+
+  val runner: PythonStreamingSourceRunner =
+    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+  runner.init()
+
+  override def initialOffset(): Offset = 
PythonStreamingSourceOffset(runner.initialOffset())
+
+  override def latestOffset(): Offset = 
PythonStreamingSourceOffset(runner.latestOffset())
+

Review Comment:
   nit: two consecutive empty lines are unnecessary; remove one



##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -0,0 +1,178 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import sys
+import functools
+import json
+from itertools import islice
+from typing import IO, List, Iterator, Iterable
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.connect.conversion import ArrowTableToRowsConversion, 
LocalDataToArrowConversion
+from pyspark.sql.datasource import DataSource, InputPartition
+from pyspark.sql.pandas.types import to_arrow_schema
+from pyspark.sql.types import (
+    _parse_datatype_json_string,
+    BinaryType,
+    StructType,
+)
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    read_command,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+    utf8_deserializer,
+)
+
+initial_offset_func_id = 884
+latest_offset_func_id = 885
+partitions_func_id = 886
+commit_func_id = 887
+
+
+def initial_offset_func(reader, outfile):
+    offset = reader.initialOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def latest_offset_func(reader, outfile):
+    offset = reader.latestOffset()
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
+def partitions_func(reader, infile, outfile):
+    start_offset = json.loads(utf8_deserializer.loads(infile))
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    partitions = reader.partitions(start_offset, end_offset)
+    # Return the serialized partition values.
+    write_int(len(partitions), outfile)
+    for partition in partitions:
+        pickleSer._write_with_length(partition, outfile)
+
+
+def commit_func(reader, infile, outfile):
+    end_offset = json.loads(utf8_deserializer.loads(infile))
+    reader.commit(end_offset)
+    write_int(0, outfile)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        check_python_version(infile)
+        setup_spark_files(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", 
"-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        _accumulatorRegistry.clear()
+
+        # Receive the data source instance.
+        data_source = read_command(pickleSer, infile)
+
+        if not isinstance(data_source, DataSource):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "a Python data source instance of type 
'DataSource'",
+                    "actual": f"'{type(data_source).__name__}'",
+                },
+            )
+
+        # Receive the data source output schema.
+        schema_json = utf8_deserializer.loads(infile)
+        schema = _parse_datatype_json_string(schema_json)
+        if not isinstance(schema, StructType):
+            raise PySparkAssertionError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "an output schema of type 'StructType'",
+                    "actual": f"'{type(schema).__name__}'",
+                },
+            )
+
+        # Receive the configuration values.
+        max_arrow_batch_size = read_int(infile)
+        assert max_arrow_batch_size > 0, (
+            "The maximum arrow batch size should be greater than 0, but got "
+            f"'{max_arrow_batch_size}'"
+        )
+
+        # Instantiate data source reader.
+        try:
+            reader = data_source.streamReader(schema=schema)
+            # Initialization succeed.
+            write_int(0, outfile)
+            outfile.flush()
+
+            # handle method call from socket
+            while True:
+                func_id = read_int(infile)
+                if func_id == initial_offset_func_id:
+                    initial_offset_func(reader, outfile)
+                elif func_id == latest_offset_func_id:
+                    latest_offset_func(reader, outfile)
+                elif func_id == partitions_func_id:
+                    partitions_func(reader, infile, outfile)
+                elif func_id == commit_func_id:
+                    commit_func(reader, infile, outfile)
+                outfile.flush()
+
+        except NotImplementedError:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED",
+                message_parameters={"type": "reader", "method": "reader"},

Review Comment:
   Shouldn't we provide the actual name of the missing method?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory =
+        new PythonWorkerFactory(pythonExec, workerModule, 
envVars.asScala.toMap)
+      val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, 
dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "initialize source", msg = msg)
+    }
+  }
+
+  /**
+   * Invokes latestOffset() function of the stream reader and receive the 
return value.
+   */
+  def latestOffset(): String = {
+    dataOut.writeInt(latestOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(

Review Comment:
   Shall we check whether this error message is enough to identify which source 
has the issue when there are multiple sources?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala:
##########
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2.python
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connector.read.{InputPartition, 
PartitionReaderFactory}
+import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+case class PythonStreamingSourceOffset(json: String) extends Offset
+
+case class PythonStreamingSourcePartition(partition: Array[Byte]) extends 
InputPartition
+
+class PythonMicroBatchStream(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap
+  ) extends MicroBatchStream with Logging {
+  private def createDataSourceFunc =
+    ds.source.createPythonFunction(
+      ds.getOrCreateDataSourceInPython(shortName, options, 
Some(outputSchema)).dataSource)
+
+  val runner: PythonStreamingSourceRunner =
+    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+  runner.init()
+
+  override def initialOffset(): Offset = 
PythonStreamingSourceOffset(runner.initialOffset())
+
+  override def latestOffset(): Offset = 
PythonStreamingSourceOffset(runner.latestOffset())
+
+
+  override def planInputPartitions(start: Offset, end: Offset): 
Array[InputPartition] = {
+    runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json,
+      
end.asInstanceOf[PythonStreamingSourceOffset].json).map(PythonStreamingSourcePartition(_))
+  }
+
+  override def createReaderFactory(): PartitionReaderFactory = {
+    // TODO: fill in the implementation.

Review Comment:
   Let's file a JIRA ticket and put the ticket number in this TODO comment.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(

Review Comment:
   All protected fields in this class: do they have to be protected? Please 
change them to private unless protected access is needed.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory =
+        new PythonWorkerFactory(pythonExec, workerModule, 
envVars.asScala.toMap)
+      val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, 
dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "initialize source", msg = msg)
+    }
+  }
+
+  /**
+   * Invokes latestOffset() function of the stream reader and receive the 
return value.
+   */
+  def latestOffset(): String = {
+    dataOut.writeInt(latestOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes initialOffset() function of the stream reader and receive the 
return value.
+   */
+  def initialOffset(): String = {
+    dataOut.writeInt(initialOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"

Review Comment:
   Could you please add some guide comments for reviewers? For example, if you 
copied a block of code, note where you copied it from. That would greatly help 
reviewers avoid looking into all the details of "existing and already reviewed" code.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory =
+        new PythonWorkerFactory(pythonExec, workerModule, 
envVars.asScala.toMap)
+      val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, 
dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "initialize source", msg = msg)
+    }
+  }
+
+  /**
+   * Invokes latestOffset() function of the stream reader and receive the 
return value.
+   */
+  def latestOffset(): String = {
+    dataOut.writeInt(latestOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes initialOffset() function of the stream reader and receive the 
return value.
+   */
+  def initialOffset(): String = {
+    dataOut.writeInt(initialOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "initialOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes partitions(start, end) function of the stream reader and receive 
the return value.
+   */
+  def partitions(start: String, end: String): Array[Array[Byte]] = {
+    dataOut.writeInt(partitionsFuncId)
+    PythonWorkerUtils.writeUTF(start, dataOut)
+    PythonWorkerUtils.writeUTF(end, dataOut)
+    dataOut.flush()
+    // Receive the list of partitions, if any.
+    val pickledPartitions = ArrayBuffer.empty[Array[Byte]]
+    val numPartitions = dataIn.readInt()
+    if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(

Review Comment:
   ditto



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala:
##########
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, 
PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_USE_DAEMON}
+import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+object PythonStreamingSourceRunner {
+  val initialOffsetFuncId = 884
+  val latestOffsetFuncId = 885
+  val partitionsFuncId = 886
+  val commitFuncId = 887
+}
+
+/**
+ * This class is a proxy to invoke methods in Python DataSourceStreamReader 
from JVM.
+ * A runner spawns a python worker process. In the main function, set up 
communication
+ * between JVM and python process through socket and create a 
DataSourceStreamReader instance.
+ * In an infinite loop, the python worker process poll information(function 
name and parameters)
+ * from the socket, invoke the corresponding method of StreamReader and send 
return value to JVM.
+ */
+class PythonStreamingSourceRunner(
+    func: PythonFunction,
+    outputSchema: StructType) extends Logging  {
+  val workerModule = "pyspark.sql.streaming.python_streaming_source_runner"
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+
+  private val envVars: java.util.Map[String, String] = func.envVars
+  private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[PythonWorker] = None
+  private var pythonWorkerFactory: Option[PythonWorkerFactory] = None
+  protected val pythonVer: String = func.pythonVer
+
+  private var dataOut: DataOutputStream = null
+  private var dataIn: DataInputStream = null
+
+  import PythonStreamingSourceRunner._
+
+  /**
+   * Initializes the Python worker for running the streaming source.
+   */
+  def init(): Unit = {
+    logInfo(s"Initializing Python runner pythonExec: $pythonExec")
+    val env = SparkEnv.get
+
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => 
f.getPath()).mkString(",")
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val workerFactory =
+        new PythonWorkerFactory(pythonExec, workerModule, 
envVars.asScala.toMap)
+      val (worker: PythonWorker, _) = 
workerFactory.createSimpleWorker(blockingMode = true)
+      pythonWorker = Some(worker)
+      pythonWorkerFactory = Some(workerFactory)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
+
+    val stream = new BufferedOutputStream(
+      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+    dataOut = new DataOutputStream(stream)
+
+    PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, 
dataOut)
+
+    // Send the user function to python process
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send output schema
+    PythonWorkerUtils.writeUTF(outputSchema.json, dataOut)
+
+    // Send configurations
+    dataOut.writeInt(SQLConf.get.arrowMaxRecordsPerBatch)
+    dataOut.flush()
+
+    dataIn = new DataInputStream(
+      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+
+    val initStatus = dataIn.readInt()
+    if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.pythonDataSourceError(
+        action = "plan", tpe = "initialize source", msg = msg)
+    }
+  }
+
+  /**
+   * Invokes latestOffset() function of the stream reader and receive the 
return value.
+   */
+  def latestOffset(): String = {
+    dataOut.writeInt(latestOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes initialOffset() function of the stream reader and receive the 
return value.
+   */
+  def initialOffset(): String = {
+    dataOut.writeInt(initialOffsetFuncId)
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "initialOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  /**
+   * Invokes partitions(start, end) function of the stream reader and receive 
the return value.
+   */
+  def partitions(start: String, end: String): Array[Array[Byte]] = {
+    dataOut.writeInt(partitionsFuncId)
+    PythonWorkerUtils.writeUTF(start, dataOut)
+    PythonWorkerUtils.writeUTF(end, dataOut)
+    dataOut.flush()
+    // Receive the list of partitions, if any.
+    val pickledPartitions = ArrayBuffer.empty[Array[Byte]]
+    val numPartitions = dataIn.readInt()
+    if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "planPartitions", msg)
+    }
+    for (_ <- 0 until numPartitions) {
+      val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
+      pickledPartitions.append(pickledPartition)
+    }
+    pickledPartitions.toArray
+  }
+
+  /**
+   * Invokes commit(end) function of the stream reader and receive the return 
value.
+   */
+  def commit(end: String): Unit = {
+    dataOut.writeInt(commitFuncId)
+    PythonWorkerUtils.writeUTF(end, dataOut)
+    dataOut.flush()
+    val status = dataIn.readInt()
+    if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(

Review Comment:
   ditto



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.AnalysisException
+import 
org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, 
shouldTestPandasUDFs}
+import 
org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, 
PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
+
+  protected def simpleDataStreamReaderScript: String =
+    """
+      |from pyspark.sql.datasource import DataSourceStreamReader, 
InputPartition
+      |
+      |class SimpleDataStreamReader(DataSourceStreamReader):
+      |    def initialOffset(self):
+      |        return {"0": "2"}
+      |    def latestOffset(self):
+      |        return {"0": "2"}

Review Comment:
   Why not use different offsets for initialOffset and latestOffset, to distinguish them?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to