This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch 
WIP-python-data-source-admission-control-trigger-availablenow-change-the-method-signature
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 9b0a9b34b071bc87590941a6d4c9cd7c339bc0eb
Author: Jungtaek Lim <[email protected]>
AuthorDate: Mon Jan 19 14:58:48 2026 +0900

    WIP python data source Trigger.AvailableNow
---
 python/pyspark/sql/datasource.py                   |  55 ++++++++++
 python/pyspark/sql/datasource_internal.py          |  21 +++-
 .../streaming/python_streaming_source_runner.py    |  70 ++++++++++++-
 .../v2/python/PythonMicroBatchStream.scala         |  64 ++++++++++--
 .../datasources/v2/python/PythonScan.scala         |  19 +++-
 .../streaming/PythonStreamingSourceRunner.scala    |  58 +++++++++++
 .../streaming/PythonStreamingDataSourceSuite.scala | 116 +++++++++++++++++----
 7 files changed, 369 insertions(+), 34 deletions(-)
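
For context when reviewing: a rough sketch of how a user-defined stream reader might opt
into the two interfaces this change introduces. The reader name, offset layout, and batch
size below are made up for illustration; only the mixin classes and method signatures come
from the patch.

    from pyspark.sql.datasource import (
        DataSourceStreamReader,
        ReadAllAvailable,
        ReadLimit,
        SupportsAdmissionControl,
        SupportsTriggerAvailableNow,
    )


    class CounterStreamReader(
        DataSourceStreamReader, SupportsAdmissionControl, SupportsTriggerAvailableNow
    ):
        def __init__(self):
            self.target = None  # target offset captured for Trigger.AvailableNow

        def initialOffset(self) -> dict:
            return {"offset": 0}

        def prepareForTriggerAvailableNow(self) -> None:
            # Record the latest data available right now; latestOffset() must not go
            # past this point while the query runs with Trigger.AvailableNow.
            self.target = {"offset": 100}

        def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
            # Only ReadAllAvailable is wired through in this WIP.
            assert isinstance(readLimit, ReadAllAvailable)
            if self.target is not None:
                return self.target
            return {"offset": start["offset"] + 10}

        def partitions(self, start: dict, end: dict):
            ...  # split [start, end) into InputPartitions as usual

        def read(self, partition):
            ...  # yield rows for the given partition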

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index f1908180a3ba..854f67217acf 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -909,6 +909,61 @@ class SimpleDataSourceStreamReader(ABC):
         ...
 
 
+class ReadLimit(ABC):
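+    """
+    Represents a limit on how much data to read from a streaming source in a micro-batch.
+    """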
+    pass
+
+
+class ReadAllAvailable(ReadLimit):
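+    """
+    A :class:`ReadLimit` indicating that all data available at the source should be read.
+    """
+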
+    def __init__(self):
+        pass
+
+
+class SupportsAdmissionControl(ABC):
+    @abstractmethod
+    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
+        """
+        Returns the most recent offset available given a read limit. The start offset
+        can be used to figure out how much new data should be read given the limit.
+        Users should implement this method instead of the no-argument
+        `DataSourceStreamReader.latestOffset()`.
+
+        For the very first micro-batch, `start` will be the initial offset of the
+        source. The source can return `None` if there is no data to process.
+        """
+        ...
+
+
+class SupportsTriggerAvailableNow(ABC):
+    @abstractmethod
+    def prepareForTriggerAvailableNow(self) -> None:
+        """
+        Called at the beginning of a streaming query with Trigger.AvailableNow, to let
+        the source record the offset of the latest data available at that time (the
+        target offset of the query). The source should behave as if no new data arrives
+        after the target offset, i.e. it must not return an offset higher than the
+        target offset from `latestOffset(start, readLimit)`.
+
+        Note that there is an exception for the first uncommitted batch after a restart,
+        where the end offset is not derived from the current latest offset. Sources need
+        special consideration if they want to assert such a relation. One possible way
+        is to keep an internal flag indicating Trigger.AvailableNow, set it in this
+        method, and record the target offset on the first call to
+        `latestOffset(start, readLimit)`.
+        """
+        ...
+
+
 class DataSourceWriter(ABC):
     """
     A base class for data source writers. Data source writers are responsible for saving
diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py
index 6df0be4192ec..8b0fa5eb0a5f 100644
--- a/python/pyspark/sql/datasource_internal.py
+++ b/python/pyspark/sql/datasource_internal.py
@@ -25,7 +25,10 @@ from pyspark.sql.datasource import (
     DataSource,
     DataSourceStreamReader,
     InputPartition,
+    ReadAllAvailable,
+    ReadLimit,
     SimpleDataSourceStreamReader,
+    SupportsAdmissionControl,
 )
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError
@@ -56,7 +59,7 @@ class PrefetchedCacheEntry:
         self.iterator = iterator
 
 
-class _SimpleStreamReaderWrapper(DataSourceStreamReader):
+class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionControl):
     """
     A private class that wrap :class:`SimpleDataSourceStreamReader` in prefetch and cache pattern,
     so that :class:`SimpleDataSourceStreamReader` can integrate with streaming engine like an
@@ -97,6 +100,22 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader):
         self.current_offset = end
         return end
 
+    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
+        if self.current_offset is None:
+            assert start is not None, "start offset should not be None"
+            self.current_offset = start
+        else:
+            assert self.current_offset == start, (
+                "start offset does not match current offset. "
+                f"current: {self.current_offset}, start: {start}"
+            )
+
+        assert isinstance(readLimit, ReadAllAvailable), (
+            "simple stream reader does not support read limit"
+        )
+
+        (iterator, end) = self.simple_reader.read(self.current_offset)
+        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iterator))
+        self.current_offset = end
+        return end
+
     def commit(self, end: dict) -> None:
         if self.current_offset is None:
             self.current_offset = end
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index ab988eb714cc..54bf12843232 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -28,7 +28,13 @@ from pyspark.serializers import (
     write_with_length,
     SpecialLengths,
 )
-from pyspark.sql.datasource import DataSource, DataSourceStreamReader
+from pyspark.sql.datasource import (
+    DataSource,
+    DataSourceStreamReader,
+    ReadAllAvailable,
+    SupportsAdmissionControl,
+    SupportsTriggerAvailableNow,
+)
 from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
@@ -51,11 +57,17 @@ INITIAL_OFFSET_FUNC_ID = 884
 LATEST_OFFSET_FUNC_ID = 885
 PARTITIONS_FUNC_ID = 886
 COMMIT_FUNC_ID = 887
+CHECK_SUPPORTED_FEATURES_ID = 888
+PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
 
 PREFETCHED_RECORDS_NOT_FOUND = 0
 NON_EMPTY_PYARROW_RECORD_BATCHES = 1
 EMPTY_PYARROW_RECORD_BATCHES = 2
 
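+# Bit flags reported by check_support_func; they must match the decoding done in
+# PythonStreamingSourceRunner.checkSupportedFeatures() on the JVM side.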
+SUPPORTS_ADMISSION_CONTROL = 1
+SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1
+
 
 def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
     offset = reader.initialOffset()
@@ -116,6 +128,56 @@ def send_batch_func(
         write_int(EMPTY_PYARROW_RECORD_BATCHES, outfile)
 
 
+def check_support_func(reader, outfile):
+    support_flags = 0
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        # We consider the `read` method of simple_reader to already have admission control
+        # built into it.
+        support_flags |= SUPPORTS_ADMISSION_CONTROL
+        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+    else:
+        if isinstance(reader, SupportsAdmissionControl):
+            support_flags |= SUPPORTS_ADMISSION_CONTROL
+        if isinstance(reader, SupportsTriggerAvailableNow):
+            support_flags |= SUPPORTS_TRIGGER_AVAILABLE_NOW
+    write_int(support_flags, outfile)
+
+
+def prepare_for_trigger_available_now_func(reader, outfile):
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        if isinstance(reader.simple_reader, SupportsTriggerAvailableNow):
+            reader.simple_reader.prepareForTriggerAvailableNow()
+        else:
+            # FIXME: code for not supported? or should it be assertion?
+            raise Exception(
+                "prepareForTriggerAvailableNow is not supported by the underlying simple reader."
+            )
+    else:
+        if isinstance(reader, SupportsTriggerAvailableNow):
+            reader.prepareForTriggerAvailableNow()
+        else:
+            # FIXME: code for not supported? or should it be assertion?
+            raise Exception(
+                "prepareForTriggerAvailableNow is not supported by the stream reader."
+            )
+    write_int(0, outfile)
+
+
+def latest_offset_admission_control_func(reader, infile, outfile):
+    start_offset_dict = json.loads(utf8_deserializer.loads(infile))
+
+    limit_type = read_int(infile)
+    if limit_type == 0:
+        # ReadAllAvailable
+        limit = ReadAllAvailable()
+    else:
+        # FIXME: code for not supported?
+        raise Exception("Only ReadAllAvailable is supported for latestOffsetAdmissionControl.")
+
+    offset = reader.latestOffset(start_offset_dict, limit)
+    write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
@@ -176,6 +238,12 @@ def main(infile: IO, outfile: IO) -> None:
                     )
                 elif func_id == COMMIT_FUNC_ID:
                     commit_func(reader, infile, outfile)
+                elif func_id == CHECK_SUPPORTED_FEATURES_ID:
+                    check_support_func(reader, outfile)
+                elif func_id == PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID:
+                    prepare_for_trigger_available_now_func(reader, outfile)
+                elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID:
+                    latest_offset_admission_control_func(reader, infile, outfile)
                 else:
                     raise IllegalArgumentException(
                         errorClass="UNSUPPORTED_OPERATION",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 50ea7616061c..e4fef9e6763c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -17,9 +17,10 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.SparkEnv
+import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
-import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset}
+import org.apache.spark.sql.connector.read.streaming.{AcceptsLatestSeenOffset, MicroBatchStream, Offset, ReadLimit, SupportsAdmissionControl, SupportsTriggerAvailableNow}
 import org.apache.spark.sql.execution.datasources.v2.python.PythonMicroBatchStream.nextStreamId
 import org.apache.spark.sql.execution.python.streaming.PythonStreamingSourceRunner
 import org.apache.spark.sql.types.StructType
@@ -32,14 +33,12 @@ class PythonMicroBatchStream(
     ds: PythonDataSourceV2,
     shortName: String,
     outputSchema: StructType,
-    options: CaseInsensitiveStringMap
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner
   )
   extends MicroBatchStream
   with Logging
   with AcceptsLatestSeenOffset {
-  private def createDataSourceFunc =
-    ds.source.createPythonFunction(
-      ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource)
 
   private val streamId = nextStreamId
   private var nextBlockId = 0L
@@ -49,10 +48,6 @@ class PythonMicroBatchStream(
   // from python to JVM.
   private var cachedInputPartition: Option[(String, String, PythonStreamingInputPartition)] = None
 
-  private val runner: PythonStreamingSourceRunner =
-    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
-  runner.init()
-
   override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset())
 
   override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset())
@@ -110,10 +105,61 @@ class PythonMicroBatchStream(
   override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json)
 }
 
+class PythonMicroBatchStreamWithAdmissionControl(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner)
+  extends PythonMicroBatchStream(ds, shortName, outputSchema, options, runner)
+  with SupportsAdmissionControl {
+
+  override def latestOffset(): Offset = {
+    throw new IllegalStateException("latestOffset without parameters is not expected to be " +
+      "called. Please use latestOffset(startOffset: Offset, limit: ReadLimit) instead.")
+  }
+
+  override def latestOffset(startOffset: Offset, limit: ReadLimit): Offset = {
+    PythonStreamingSourceOffset(runner.latestOffset(startOffset, limit))
+  }
+}
+
+class PythonMicroBatchStreamWithTriggerAvailableNow(
+    ds: PythonDataSourceV2,
+    shortName: String,
+    outputSchema: StructType,
+    options: CaseInsensitiveStringMap,
+    runner: PythonStreamingSourceRunner)
+  extends PythonMicroBatchStreamWithAdmissionControl(ds, shortName, outputSchema, options, runner)
+  with SupportsTriggerAvailableNow {
+
+  override def prepareForTriggerAvailableNow(): Unit = {
+    runner.prepareForTriggerAvailableNow()
+  }
+}
+
 object PythonMicroBatchStream {
   private var currentId = 0
   def nextStreamId: Int = synchronized {
     currentId = currentId + 1
     currentId
   }
+
+  def createPythonStreamingSourceRunner(
+      ds: PythonDataSourceV2,
+      shortName: String,
+      outputSchema: StructType,
+      options: CaseInsensitiveStringMap): PythonStreamingSourceRunner = {
+
+    // These methods were previously called during the construction of PythonMicroBatchStream,
+    // so there is no timing/sequencing issue in calling them here.
+    def createDataSourceFunc: PythonFunction =
+      ds.source.createPythonFunction(
+        ds.getOrCreateDataSourceInPython(
+          shortName,
+          options,
+          Some(outputSchema)).dataSource)
+
+    new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index a133c40cde60..9e3effe7d441 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -35,8 +35,23 @@ class PythonScan(
 ) extends Scan with SupportsMetadata {
   override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options)
 
-  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream =
-    new PythonMicroBatchStream(ds, shortName, outputSchema, options)
+  override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
+    val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+      ds, shortName, outputSchema, options)
+    runner.init()
+
+    val supportedFeatures = runner.checkSupportedFeatures()
+
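+    // Pick the most capable MicroBatchStream implementation based on the features
+    // reported by the Python worker.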
+    if (supportedFeatures.triggerAvailableNow) {
+      new PythonMicroBatchStreamWithTriggerAvailableNow(
+        ds, shortName, outputSchema, options, runner)
+    } else if (supportedFeatures.admissionControl) {
+      new PythonMicroBatchStreamWithAdmissionControl(
+        ds, shortName, outputSchema, options, runner)
+    } else {
+      new PythonMicroBatchStream(ds, shortName, outputSchema, options, runner)
+    }
+  }
 
   override def description: String = "(Python)"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 270d816e9bd9..36e4d09041ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.LogKeys.PYTHON_EXEC
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.read.streaming.{Offset, ReadAllAvailable, ReadLimit}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
@@ -46,11 +47,16 @@ object PythonStreamingSourceRunner {
   val LATEST_OFFSET_FUNC_ID = 885
   val PARTITIONS_FUNC_ID = 886
   val COMMIT_FUNC_ID = 887
+  val CHECK_SUPPORTED_FEATURES_ID = 888
+  val PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
+  val LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
   // Status code for JVM to decide how to receive prefetched record batches
   // for simple stream reader.
   val PREFETCHED_RECORDS_NOT_FOUND = 0
   val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
   val EMPTY_PYARROW_RECORD_BATCHES = 2
+
+  case class SupportedFeatures(admissionControl: Boolean, triggerAvailableNow: Boolean)
 }
 
 /**
@@ -129,6 +135,34 @@ class PythonStreamingSourceRunner(
     }
   }
 
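+  /**
+   * Asks the Python worker which optional features the stream reader supports and returns
+   * them as a SupportedFeatures.
+   */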
+  def checkSupportedFeatures(): SupportedFeatures = {
+    dataOut.writeInt(CHECK_SUPPORTED_FEATURES_ID)
+    dataOut.flush()
+
+    val featureBits = dataIn.readInt()
+    if (featureBits == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "checkSupportedFeatures", msg)
+    }
+    val admissionControl = (featureBits & (1 << 0)) != 0
+    val availableNow = (featureBits & (1 << 1)) != 0
+
+    SupportedFeatures(admissionControl, availableNow)
+  }
+
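+  /**
+   * Invokes prepareForTriggerAvailableNow() function of the stream reader.
+   */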
+  def prepareForTriggerAvailableNow(): Unit = {
+    dataOut.writeInt(PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID)
+    dataOut.flush()
+    val status = dataIn.readInt()
+    // FIXME: code for not supported?
+    if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "prepareForTriggerAvailableNow", msg)
+    }
+  }
+
   /**
    * Invokes latestOffset() function of the stream reader and receive the return value.
    */
@@ -144,6 +178,30 @@ class PythonStreamingSourceRunner(
     PythonWorkerUtils.readUTF(len, dataIn)
   }
 
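+  /**
+   * Invokes latestOffset(start, readLimit) function of the stream reader and receive the
+   * return value.
+   */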
+  def latestOffset(startOffset: Offset, limit: ReadLimit): String = {
+    dataOut.writeInt(LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID)
+    PythonWorkerUtils.writeUTF(startOffset.json, dataOut)
+    limit match {
+      case _: ReadAllAvailable =>
+        dataOut.writeInt(0)
+
+      case _ =>
+        // FIXME: Add support for other ReadLimit types
+        // throw QueryExecutionErrors.unsupportedReadLimitTypeError(limit.getClass.getName)
+        throw new UnsupportedOperationException(
+          s"Unsupported ReadLimit type: ${limit.getClass.getName}")
+    }
+    dataOut.flush()
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "latestOffset", msg)
+    }
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
   /**
    * Invokes initialOffset() function of the stream reader and receive the 
return value.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 0e33b6e55a43..330a0513d360 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -24,7 +24,8 @@ import scala.concurrent.duration._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs}
-import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset}
+import org.apache.spark.sql.connector.read.streaming.ReadLimit
+import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonMicroBatchStreamWithAdmissionControl, PythonStreamingSourceOffset}
 import org.apache.spark.sql.execution.python.PythonDataSourceSuiteBase
 import org.apache.spark.sql.execution.streaming.ProcessingTimeTrigger
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, OffsetSeqLog}
@@ -249,12 +250,18 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     pythonDs.setShortName("ErrorDataSource")
 
     def testMicroBatchStreamError(action: String, msg: String)(
-        func: PythonMicroBatchStream => Unit): Unit = {
-      val stream = new PythonMicroBatchStream(
+        func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
+      val stream = new PythonMicroBatchStreamWithAdmissionControl(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -277,16 +284,6 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
       stream =>
         stream.initialOffset()
     }
-
-    // User don't need to implement latestOffset for SimpleDataSourceStreamReader.
-    // The latestOffset method of simple stream reader invokes initialOffset() and read()
-    // So the not implemented method is initialOffset.
-    testMicroBatchStreamError(
-      "latestOffset",
-      "[NOT_IMPLEMENTED] initialOffset is not implemented") {
-      stream =>
-        stream.latestOffset()
-    }
   }
 
   test("read() method throw error in SimpleDataSourceStreamReader") {
@@ -314,12 +311,18 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     pythonDs.setShortName("ErrorDataSource")
 
     def testMicroBatchStreamError(action: String, msg: String)(
-        func: PythonMicroBatchStream => Unit): Unit = {
-      val stream = new PythonMicroBatchStream(
+        func: PythonMicroBatchStreamWithAdmissionControl => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
+      val stream = new PythonMicroBatchStreamWithAdmissionControl(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -337,7 +340,59 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     }
 
     testMicroBatchStreamError("latestOffset", "Exception: error reading available data") { stream =>
-      stream.latestOffset()
+      stream.latestOffset(PythonStreamingSourceOffset("""{"partition": 0}"""),
+        ReadLimit.allAvailable())
+    }
+  }
+
+  test("SimpleDataSourceStreamReader with Trigger.AvailableNow") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader, SupportsTriggerAvailableNow
+         |
+         |class SimpleDataStreamReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow):
+         |    def initialOffset(self):
+         |        return {"partition-1": 0}
+         |    def read(self, start: dict):
+         |        start_idx = start["partition-1"]
+         |        end_offset = min(start_idx + 2, self.desired_end_offset)
+         |        it = iter([(i, ) for i in range(start_idx, end_offset)])
+         |        return (it, {"partition-1": end_offset})
+         |    def readBetweenOffsets(self, start: dict, end: dict):
+         |        start_idx = start["partition-1"]
+         |        end_idx = end["partition-1"]
+         |        return iter([(i, ) for i in range(start_idx, end_idx)])
+         |    def prepareForTriggerAvailableNow(self):
+         |        self.desired_end_offset = 10
+         |
+         |class $dataSourceName(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT"
+         |    def simpleStreamReader(self, schema):
+         |        return SimpleDataStreamReader()
+         |""".stripMargin
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      val q = df.writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .format("json")
+        .trigger(Trigger.AvailableNow())
+        .start(outputDir.getAbsolutePath)
+      q.awaitTermination(waitTimeout.toMillis)
+      val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+      assert(rowCount === 10)
+      checkAnswer(
+        spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 until rowCount.toInt).map(Row(_))
+      )
     }
   }
 
@@ -459,11 +514,18 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     spark.dataSource.registerPython(dataSourceName, dataSource)
     val pythonDs = new PythonDataSourceV2
     pythonDs.setShortName("SimpleDataSource")
+
+    val options = CaseInsensitiveStringMap.empty()
+    val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+      pythonDs, dataSourceName, inputSchema, options)
+    runner.init()
+
     val stream = new PythonMicroBatchStream(
       pythonDs,
       dataSourceName,
       inputSchema,
-      CaseInsensitiveStringMap.empty()
+      options,
+      runner
     )
 
     var startOffset = stream.initialOffset()
@@ -706,11 +768,17 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     def testMicroBatchStreamError(action: String, msg: String)(
         func: PythonMicroBatchStream => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
       val stream = new PythonMicroBatchStream(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
@@ -767,11 +835,17 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
 
     def testMicroBatchStreamError(action: String, msg: String)(
         func: PythonMicroBatchStream => Unit): Unit = {
+      val options = CaseInsensitiveStringMap.empty()
+      val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
+        pythonDs, errorDataSourceName, inputSchema, options)
+      runner.init()
+
       val stream = new PythonMicroBatchStream(
         pythonDs,
         errorDataSourceName,
         inputSchema,
-        CaseInsensitiveStringMap.empty()
+        options,
+        runner
       )
       val err = intercept[SparkException] {
         func(stream)
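
For completeness, query-level usage of a source whose reader implements the new interfaces
stays the same as for any other streaming source. A minimal sketch, assuming a Python data
source has already been registered under the hypothetical short name "my_python_source":

    # Trigger.AvailableNow on a Python data source now routes through the new
    # admission-control path (PythonMicroBatchStreamWithTriggerAvailableNow).
    df = spark.readStream.format("my_python_source").load()
    query = (
        df.writeStream.format("json")
        .option("checkpointLocation", "/tmp/checkpoint")
        .trigger(availableNow=True)
        .start("/tmp/output")
    )
    query.awaitTermination()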


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
