This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-python-data-source-admission-control-trigger-availablenow-change-the-method-signature
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 35e4ec1714bf1f67008b03e0396fabae7a8b7ed8
Author: Jungtaek Lim <[email protected]>
AuthorDate: Wed Jan 21 14:18:41 2026 +0900

    address admission control and Trigger.AvailableNow for the stream reader: built-in read limits, testing...
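
    The user-facing shape this targets, as exercised by the new tests. This is
    an illustrative sketch only: RangeStreamReader/RangeDataSource and the
    "offset"/target_offset names are placeholders, while the classes imported
    from pyspark.sql.streaming.datasource are the ones introduced here.

        from pyspark.sql.datasource import (
            DataSource,
            DataSourceStreamReader,
            InputPartition,
        )
        from pyspark.sql.streaming.datasource import (
            ReadAllAvailable,
            ReadLimit,
            ReadMaxRows,
            SupportsAdmissionControl,
            SupportsTriggerAvailableNow,
        )

        class RangeStreamReader(
            DataSourceStreamReader, SupportsAdmissionControl, SupportsTriggerAvailableNow
        ):
            # Emits consecutive integers, 2 rows per micro-batch by default.
            def initialOffset(self):
                return {"offset": 0}

            def getDefaultReadLimit(self) -> ReadLimit:
                return ReadMaxRows(2)

            def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
                # Changed signature: the engine passes the start offset and the
                # effective read limit for this micro-batch.
                if isinstance(readLimit, ReadAllAvailable):
                    end = start["offset"] + 10
                else:
                    end = start["offset"] + readLimit.max_rows
                # Honor the target offset captured for Trigger.AvailableNow, if set.
                return {"offset": min(end, getattr(self, "target_offset", end))}

            def reportLatestOffset(self) -> dict:
                return {"offset": 1000000}

            def prepareForTriggerAvailableNow(self) -> None:
                self.target_offset = 10

            def partitions(self, start: dict, end: dict):
                return [InputPartition(i) for i in range(start["offset"], end["offset"])]

            def read(self, partition):
                yield (partition.value,)

        class RangeDataSource(DataSource):
            def schema(self) -> str:
                return "id INT"

            def streamReader(self, schema):
                return RangeStreamReader()

    ReadLimit instances cross the JVM/Python boundary as JSON via dump()/load(),
    e.g. ReadMaxRows(2).dump() == {"max_rows": 2, "type": "ReadMaxRows"}.
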
---
 python/pyspark/sql/datasource.py                   |  54 ------
 python/pyspark/sql/datasource_internal.py          |  46 ++++-
 python/pyspark/sql/streaming/datasource.py         | 207 +++++++++++++++++++++
 .../streaming/python_streaming_source_runner.py    |  51 +++--
 .../sql/tests/test_python_streaming_datasource.py  |  76 ++++++++
 .../v2/python/PythonMicroBatchStream.scala         |  16 ++
 .../streaming/PythonStreamingSourceRunner.scala    |  59 ++++--
 .../streaming/PythonStreamingDataSourceSuite.scala | 206 +++++++++++++++++++-
 8 files changed, 625 insertions(+), 90 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index 854f67217acf..6adfd79f2631 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -909,60 +909,6 @@ class SimpleDataSourceStreamReader(ABC):
         ...
 
 
-class ReadLimit(ABC):
-    pass
-
-
-class ReadAllAvailable(ReadLimit):
-    def __init__(self):
-        pass
-
-
-class SupportsAdmissionControl(ABC):
-    @abstractmethod
-    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
-        """
-        FIXME: docstring needed
-
-      /**
-       * Returns the most recent offset available given a read limit. The start offset can be used
-       * to figure out how much new data should be read given the limit. Users should implement this
-       * method instead of latestOffset for a MicroBatchStream or getOffset for Source.
-       * <p>
-       * When this method is called on a `Source`, the source can return `null` if there is no
-       * data to process. In addition, for the very first micro-batch, the `startOffset` will be
-       * null as well.
-       * <p>
-       * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset`
-       * for the very first micro-batch. The source can return `null` if there is no data to process.
-       */
-        """
-        pass
-
-
-class SupportsTriggerAvailableNow(ABC):
-    @abstractmethod
-    def prepareForTriggerAvailableNow(self) -> None:
-        """
-        FIXME: docstring needed
-
-        /**
-         * This will be called at the beginning of streaming queries with Trigger.AvailableNow, to let the
-         * source record the offset for the current latest data at the time (a.k.a the target offset for
-         * the query). The source will behave as if there is no new data coming in after the target
-         * offset, i.e., the source will not return an offset higher than the target offset when
-         * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called.
-         * <p>
-         * Note that there is an exception on the first uncommitted batch after a restart, where the end
-         * offset is not derived from the current latest offset. Sources need to take special
-         * considerations if wanting to assert such relation. One possible way is to have an internal
-         * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method,
-         * and record the target offset in the first call of
-         * {@link #latestOffset(Offset, ReadLimit) latestOffset}.
-         */
-        """
-        pass
-
 
 class DataSourceWriter(ABC):
     """
diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py
index 8b0fa5eb0a5f..51ef94f9a325 100644
--- a/python/pyspark/sql/datasource_internal.py
+++ b/python/pyspark/sql/datasource_internal.py
@@ -19,16 +19,22 @@
 import json
 import copy
 from itertools import chain
-from typing import Iterator, List, Optional, Sequence, Tuple
+from typing import Dict, Iterator, List, Optional, Sequence, Tuple, Type
 
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceStreamReader,
     InputPartition,
+    SimpleDataSourceStreamReader,
+)
+from pyspark.sql.streaming.datasource import (
     ReadAllAvailable,
     ReadLimit,
-    SimpleDataSourceStreamReader,
     SupportsAdmissionControl,
+    CompositeReadLimit,
+    ReadMaxBytes,
+    ReadMaxRows,
+    ReadMinRows,
 )
 from pyspark.sql.types import StructType
 from pyspark.errors import PySparkNotImplementedError
@@ -91,14 +97,9 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionContro
             self.initial_offset = self.simple_reader.initialOffset()
         return self.initial_offset
 
-    def latestOffset(self) -> dict:
-        # when query start for the first time, use initial offset as the start offset.
-        if self.current_offset is None:
-            self.current_offset = self.initialOffset()
-        (iter, end) = self.simple_reader.read(self.current_offset)
-        self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
-        self.current_offset = end
-        return end
+    def getDefaultReadLimit(self) -> ReadLimit:
+        # We do not consider providing a different read limit for the simple stream reader.
+        return ReadAllAvailable()
 
     def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
         if self.current_offset is None:
@@ -163,3 +164,28 @@ class _SimpleStreamReaderWrapper(DataSourceStreamReader, SupportsAdmissionContro
         self, input_partition: SimpleInputPartition  # type: ignore[override]
     ) -> Iterator[Tuple]:
         return self.simple_reader.readBetweenOffsets(input_partition.start, input_partition.end)
+
+
+class ReadLimitRegistry:
+    def __init__(self) -> None:
+        self._registry: Dict[str, Type[ReadLimit]] = {}
+        self.__register(ReadAllAvailable.type_name(), ReadAllAvailable)
+        self.__register(ReadMinRows.type_name(), ReadMinRows)
+        self.__register(ReadMaxRows.type_name(), ReadMaxRows)
+        self.__register(ReadMaxBytes.type_name(), ReadMaxBytes)
+        self.__register(CompositeReadLimit.type_name(), CompositeReadLimit)
+
+    def __register(self, type_name: str, read_limit_type: Type["ReadLimit"]) -> None:
+        if type_name in self._registry:
+            # FIXME: error class?
+            raise Exception(f"ReadLimit type '{type_name}' is already registered.")
+
+        self._registry[type_name] = read_limit_type
+
+    def get(self, type_name: str, params: dict) -> ReadLimit:
+        read_limit_type = self._registry.get(type_name)
+        if read_limit_type is None:
+            raise Exception("type_name '{}' is not registered.".format(type_name))
+        params_without_type = params.copy()
+        del params_without_type["type"]
+        return read_limit_type.load(params_without_type)
diff --git a/python/pyspark/sql/streaming/datasource.py b/python/pyspark/sql/streaming/datasource.py
new file mode 100644
index 000000000000..82afbb83c583
--- /dev/null
+++ b/python/pyspark/sql/streaming/datasource.py
@@ -0,0 +1,207 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+class ReadLimit(ABC):
+    @classmethod
+    @abstractmethod
+    def type_name(cls) -> str:
+        pass
+
+    @classmethod
+    @abstractmethod
+    def load(cls, params: dict) -> "ReadLimit":
+        pass
+
+    def dump(self) -> dict:
+        params = self._dump()
+        params.update({"type": self.type_name()})
+        return params
+
+    @abstractmethod
+    def _dump(self) -> dict:
+        pass
+
+
+class ReadAllAvailable(ReadLimit):
+    @classmethod
+    def type_name(cls) -> str:
+        return "ReadAllAvailable"
+
+    @classmethod
+    def load(cls, params: dict) -> "ReadAllAvailable":
+        return ReadAllAvailable()
+
+    def _dump(self) -> dict:
+        return {}
+
+
+class ReadMinRows(ReadLimit):
+    def __init__(self, min_rows: int) -> None:
+        self.min_rows = min_rows
+
+    @classmethod
+    def type_name(cls) -> str:
+        return "ReadMinRows"
+
+    @classmethod
+    def load(cls, params: dict) -> "ReadMinRows":
+        return ReadMinRows(params["min_rows"])
+
+    def _dump(self) -> dict:
+        return {"min_rows": self.min_rows}
+
+
+class ReadMaxRows(ReadLimit):
+    def __init__(self, max_rows: int) -> None:
+        self.max_rows = max_rows
+
+    @classmethod
+    def type_name(cls) -> str:
+        return "ReadMaxRows"
+
+    @classmethod
+    def load(cls, params: dict) -> "ReadMaxRows":
+        return ReadMaxRows(params["max_rows"])
+
+    def _dump(self) -> dict:
+        return {"max_rows": self.max_rows}
+
+
+class ReadMaxFiles(ReadLimit):
+    def __init__(self, max_files: int) -> None:
+        self.max_files = max_files
+
+    @classmethod
+    def type_name(cls) -> str:
+        return "ReadMaxFiles"
+
+    @classmethod
+    def load(cls, params: dict) -> "ReadMaxFiles":
+        return ReadMaxFiles(params["max_files"])
+
+    def _dump(self) -> dict:
+        return {"max_files": self.max_files}
+
+
+class ReadMaxBytes(ReadLimit):
+    def __init__(self, max_bytes: int) -> None:
+        self.max_bytes = max_bytes
+
+    @classmethod
+    def type_name(cls) -> str:
+        return "ReadMaxBytes"
+
+    @classmethod
+    def load(cls, params: dict) -> "ReadMaxBytes":
+        return ReadMaxBytes(params["max_bytes"])
+
+    def _dump(self) -> dict:
+        return {"max_bytes": self.max_bytes}
+
+
+class CompositeReadLimit(ReadLimit):
+    def __init__(self, readLimits: List[ReadLimit]) -> None:
+        self.readLimits = readLimits
+
+    @classmethod
+    def type_name(cls) -> str:
+        return "CompositeReadLimit"
+
+    @classmethod
+    def load(cls, params: dict) -> "CompositeReadLimit":
+        # Import lazily to avoid a circular import: datasource_internal imports the
+        # ReadLimit classes defined in this module.
+        from pyspark.sql.datasource_internal import ReadLimitRegistry
+        registry = ReadLimitRegistry()
+        read_limits = [registry.get(p["type"], p) for p in params["readLimits"]]
+        return CompositeReadLimit(read_limits)
+
+    def _dump(self) -> dict:
+        return {"readLimits": [rl.dump() for rl in self.readLimits]}
+
+
+class SupportsAdmissionControl(ABC):
+    def getDefaultReadLimit(self) -> ReadLimit:
+        """
+        FIXME: docstring needed
+
+      /**
+       * Returns the read limits potentially passed to the data source through options when creating
+       * the data source.
+       */
+        """
+        return ReadAllAvailable()
+
+    @abstractmethod
+    def latestOffset(self, start: dict, readLimit: ReadLimit) -> dict:
+        """
+        FIXME: docstring needed
+
+      /**
+       * Returns the most recent offset available given a read limit. The start offset can be used
+       * to figure out how much new data should be read given the limit. Users should implement this
+       * method instead of latestOffset for a MicroBatchStream or getOffset for Source.
+       * <p>
+       * When this method is called on a `Source`, the source can return `null` if there is no
+       * data to process. In addition, for the very first micro-batch, the `startOffset` will be
+       * null as well.
+       * <p>
+       * When this method is called on a MicroBatchStream, the `startOffset` will be `initialOffset`
+       * for the very first micro-batch. The source can return `null` if there is no data to process.
+       */
+        """
+        pass
+
+    def reportLatestOffset(self) -> Optional[dict]:
+        """
+        FIXME: docstring needed
+
+      /**
+       * Returns the most recent offset available.
+       * <p>
+       * The source can return `null`, if there is no data to process or the source does not support
+       * to this method.
+       */
+        """
+        return None
+
+
+class SupportsTriggerAvailableNow(ABC):
+    @abstractmethod
+    def prepareForTriggerAvailableNow(self) -> None:
+        """
+        FIXME: docstring needed
+
+        /**
+         * This will be called at the beginning of streaming queries with Trigger.AvailableNow, to let the
+         * source record the offset for the current latest data at the time (a.k.a the target offset for
+         * the query). The source will behave as if there is no new data coming in after the target
+         * offset, i.e., the source will not return an offset higher than the target offset when
+         * {@link #latestOffset(Offset, ReadLimit) latestOffset} is called.
+         * <p>
+         * Note that there is an exception on the first uncommitted batch after a restart, where the end
+         * offset is not derived from the current latest offset. Sources need to take special
+         * considerations if wanting to assert such relation. One possible way is to have an internal
+         * flag in the source to indicate whether it is Trigger.AvailableNow, set the flag in this method,
+         * and record the target offset in the first call of
+         * {@link #latestOffset(Offset, ReadLimit) latestOffset}.
+         */
+        """
+        pass
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 54bf12843232..192f729a6fb9 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -31,11 +31,13 @@ from pyspark.serializers import (
 from pyspark.sql.datasource import (
     DataSource,
     DataSourceStreamReader,
+)
+from pyspark.sql.streaming.datasource import (
     ReadAllAvailable,
     SupportsAdmissionControl,
-    SupportsTriggerAvailableNow
+    SupportsTriggerAvailableNow,
 )
-from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader
+from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper, _streamReader, ReadLimitRegistry
 from pyspark.sql.pandas.serializers import ArrowStreamSerializer
 from pyspark.sql.types import (
     _parse_datatype_json_string,
@@ -60,6 +62,8 @@ COMMIT_FUNC_ID = 887
 CHECK_SUPPORTED_FEATURES_ID = 888
 PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
 LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
+GET_DEFAULT_READ_LIMIT_FUNC_ID = 891
+REPORT_LATEST_OFFSET_FUNC_ID = 892
 
 PREFETCHED_RECORDS_NOT_FOUND = 0
 NON_EMPTY_PYARROW_RECORD_BATCHES = 1
@@ -68,9 +72,12 @@ EMPTY_PYARROW_RECORD_BATCHES = 2
 SUPPORTS_ADMISSION_CONTROL = 1
 SUPPORTS_TRIGGER_AVAILABLE_NOW = 1 << 1
 
+READ_LIMIT_REGISTRY = ReadLimitRegistry()
+
 
 def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
     offset = reader.initialOffset()
+    # raise Exception(f"Debug info for initial offset: offset: {offset}, json: {json.dumps(offset).encode('utf-8')}")
     write_with_length(json.dumps(offset).encode("utf-8"), outfile)
 
 
@@ -165,19 +172,37 @@ def prepare_for_trigger_available_now_func(reader, outfile):
 def latest_offset_admission_control_func(reader, infile, outfile):
     start_offset_dict = json.loads(utf8_deserializer.loads(infile))
 
-    limit_type = read_int(infile)
-    if limit_type == 0:
-        # ReadAllAvailable
-        limit = ReadAllAvailable()
-    else:
-        # FIXME: raise error
-        # FIXME: code for not supported?
-        raise Exception("Only ReadAllAvailable is supported for latestOffsetAdmissionControl.")
+    limit = json.loads(utf8_deserializer.loads(infile))
+    limit_obj = READ_LIMIT_REGISTRY.get(limit["type"], limit)
 
-    offset = reader.latestOffset(start_offset_dict, limit)
+    offset = reader.latestOffset(start_offset_dict, limit_obj)
     write_with_length(json.dumps(offset).encode("utf-8"), outfile)
 
 
+def get_default_read_limit_func(reader, outfile):
+    if isinstance(reader, SupportsAdmissionControl):
+        limit = reader.getDefaultReadLimit()
+    else:
+        limit = ReadAllAvailable()
+
+    write_with_length(json.dumps(limit.dump()).encode("utf-8"), outfile)
+
+
+def report_latest_offset_func(reader, outfile):
+    if isinstance(reader, _SimpleStreamReaderWrapper):
+        # We do not consider providing latest offset on simple stream reader.
+        write_int(0, outfile)
+    else:
+        if isinstance(reader, SupportsAdmissionControl):
+            offset = reader.reportLatestOffset()
+            if offset is None:
+                write_int(0, outfile)
+            else:
+                write_with_length(json.dumps(offset).encode("utf-8"), outfile)
+        else:
+            write_int(0, outfile)
+
+
 def main(infile: IO, outfile: IO) -> None:
     try:
         check_python_version(infile)
@@ -244,6 +269,10 @@ def main(infile: IO, outfile: IO) -> None:
                     prepare_for_trigger_available_now_func(reader, outfile)
                 elif func_id == LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID:
                     latest_offset_admission_control_func(reader, infile, outfile)
+                elif func_id == GET_DEFAULT_READ_LIMIT_FUNC_ID:
+                    get_default_read_limit_func(reader, outfile)
+                elif func_id == REPORT_LATEST_OFFSET_FUNC_ID:
+                    report_latest_offset_func(reader, outfile)
                 else:
                     raise IllegalArgumentException(
                         errorClass="UNSUPPORTED_OPERATION",
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index 1bb0c0895be9..b894a5fcfec6 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -140,6 +140,53 @@ class BasePythonStreamingDataSourceTestsMixin:
 
         return TestDataSource
 
+    def _get_test_data_source_for_admission_control(self):
+        from pyspark.sql.streaming.datasource import (
+            ReadAllAvailable,
+            ReadLimit,
+            ReadMaxRows,
+            SupportsAdmissionControl,
+        )
+
+        class TestDataStreamReader(
+            DataSourceStreamReader,
+            SupportsAdmissionControl
+        ):
+            def initialOffset(self):
+                return {"partition-1": 0}
+
+            def getDefaultReadLimit(self):
+                return ReadMaxRows(2)
+
+            def latestOffset(self, start: dict, readLimit: ReadLimit):
+                start_idx = start["partition-1"]
+                if isinstance(readLimit, ReadAllAvailable):
+                    end_offset = start_idx + 10
+                else:
+                    assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got "
+                                                                + str(type(readLimit)))
+                    end_offset = start_idx + readLimit.max_rows
+                return {"partition-1": end_offset}
+
+            def reportLatestOffset(self):
+                return {"partition-1": 1000000}
+
+            def partitions(self, start: dict, end: dict):
+                start_index = start["partition-1"]
+                end_index = end["partition-1"]
+                return [InputPartition(i) for i in range(start_index, end_index)]
+
+            def read(self, partition):
+                yield (partition.value,)
+
+        class TestDataSource(DataSource):
+            def schema(self) -> str:
+                return "id INT"
+            def streamReader(self, schema):
+                return TestDataStreamReader()
+
+        return TestDataSource
+
     def test_stream_reader(self):
         self.spark.dataSource.register(self._get_test_data_source())
         df = self.spark.readStream.format("TestDataSource").load()
@@ -214,6 +261,35 @@ class BasePythonStreamingDataSourceTestsMixin:
 
         assertDataFrameEqual(df, expected_data)
 
+    def test_stream_reader_admission_control_trigger_once(self):
+        self.spark.dataSource.register(self._get_test_data_source_for_admission_control())
+        df = self.spark.readStream.format("TestDataSource").load()
+
+        def check_batch(df, batch_id):
+            assertDataFrameEqual(df, [Row(x) for x in range(10)])
+
+        q = df.writeStream.trigger(once=True).foreachBatch(check_batch).start()
+        q.awaitTermination()
+        self.assertIsNone(q.exception(), "No exception should be propagated.")
+        self.assertTrue(len(q.recentProgress) == 1)
+        self.assertTrue(q.lastProgress.numInputRows == 10)
+        self.assertTrue(q.lastProgress.sources[0].numInputRows == 10)
+        self.assertTrue(q.lastProgress.sources[0].latestOffset == """{"partition-1": 1000000}""")
+
+    def test_stream_reader_admission_control_processing_time_trigger(self):
+        self.spark.dataSource.register(self._get_test_data_source_for_admission_control())
+        df = self.spark.readStream.format("TestDataSource").load()
+
+        def check_batch(df, batch_id):
+            assertDataFrameEqual(df, [Row(batch_id * 2), Row(batch_id * 2 + 1)])
+
+        q = df.writeStream.foreachBatch(check_batch).start()
+        while len(q.recentProgress) < 10:
+            time.sleep(0.2)
+        q.stop()
+        q.awaitTermination()
+        self.assertIsNone(q.exception(), "No exception should be propagated.")
+
     def test_simple_stream_reader(self):
         class SimpleStreamReader(SimpleDataSourceStreamReader):
             def initialOffset(self):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index e4fef9e6763c..319083a9d8bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -29,6 +29,8 @@ import org.apache.spark.storage.{PythonStreamBlockId, StorageLevel}
 
 case class PythonStreamingSourceOffset(json: String) extends Offset
 
+case class PythonStreamingSourceReadLimit(json: String) extends ReadLimit
+
 class PythonMicroBatchStream(
     ds: PythonDataSourceV2,
     shortName: String,
@@ -122,6 +124,20 @@ class PythonMicroBatchStreamWithAdmissionControl(
   override def latestOffset(startOffset: Offset, limit: ReadLimit): Offset = {
     PythonStreamingSourceOffset(runner.latestOffset(startOffset, limit))
   }
+
+  override def getDefaultReadLimit: ReadLimit = {
+    val readLimitJson = runner.getDefaultReadLimit()
+    PythonStreamingSourceReadLimit(readLimitJson)
+  }
+
+  override def reportLatestOffset(): Offset = {
+    val offsetJson = runner.reportLatestOffset()
+    if (offsetJson == null) {
+      null
+    } else {
+      PythonStreamingSourceOffset(offsetJson)
+    }
+  }
 }
 
 class PythonMicroBatchStreamWithTriggerAvailableNow(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 36e4d09041ef..836962735d08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -20,25 +20,23 @@ package org.apache.spark.sql.execution.python.streaming
 
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
 import java.nio.channels.Channels
-
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
-
 import org.apache.arrow.vector.ipc.ArrowStreamReader
-
 import org.apache.spark.SparkEnv
 import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths}
-import org.apache.spark.internal.{Logging, LogKeys}
+import org.apache.spark.internal.{LogKeys, Logging}
 import org.apache.spark.internal.LogKeys.PYTHON_EXEC
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.streaming.{Offset, ReadAllAvailable, ReadLimit}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.datasources.v2.python.PythonStreamingSourceReadLimit
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}
 
 object PythonStreamingSourceRunner {
   // When the python process for python_streaming_source_runner receives one of the
@@ -50,11 +48,14 @@ object PythonStreamingSourceRunner {
   val CHECK_SUPPORTED_FEATURES_ID = 888
   val PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID = 889
   val LATEST_OFFSET_ADMISSION_CONTROL_FUNC_ID = 890
+  val GET_DEFAULT_READ_LIMIT_FUNC_ID = 891
+  val REPORT_LATEST_OFFSET_FUNC_ID = 892
   // Status code for JVM to decide how to receive prefetched record batches
   // for simple stream reader.
   val PREFETCHED_RECORDS_NOT_FOUND = 0
   val NON_EMPTY_PYARROW_RECORD_BATCHES = 1
   val EMPTY_PYARROW_RECORD_BATCHES = 2
+  val READ_ALL_AVAILABLE_JSON = """{"type": "ReadAllAvailable"}"""
 
   case class SupportedFeatures(admissionControl: Boolean, triggerAvailableNow: Boolean)
 }
@@ -145,17 +146,48 @@ class PythonStreamingSourceRunner(
       throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
         action = "checkSupportedFeatures", msg)
     }
-    val admissionControl = (featureBits | (1 << 0)) == 1
-    val availableNow = (featureBits | (1 << 1)) == (1 << 1)
+    val admissionControl = (featureBits & (1 << 0)) == 1
+    val availableNow = (featureBits & (1 << 1)) == (1 << 1)
 
     SupportedFeatures(admissionControl, availableNow)
   }
 
+  def getDefaultReadLimit(): String = {
+    dataOut.writeInt(GET_DEFAULT_READ_LIMIT_FUNC_ID)
+    dataOut.flush()
+
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "getDefaultReadLimit", msg)
+    }
+
+    PythonWorkerUtils.readUTF(len, dataIn)
+  }
+
+  def reportLatestOffset(): String = {
+    dataOut.writeInt(REPORT_LATEST_OFFSET_FUNC_ID)
+    dataOut.flush()
+
+    val len = dataIn.readInt()
+    if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
+        action = "reportLatestOffset", msg)
+    }
+
+    if (len == 0) {
+      null
+    } else {
+      PythonWorkerUtils.readUTF(len, dataIn)
+    }
+  }
+
   def prepareForTriggerAvailableNow(): Unit = {
     dataOut.writeInt(PREPARE_FOR_TRIGGER_AVAILABLE_NOW_FUNC_ID)
     dataOut.flush()
     val status = dataIn.readInt()
-    // FIXME: code for not supported?
     if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
       val msg = PythonWorkerUtils.readUTF(dataIn)
       throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError(
@@ -183,12 +215,15 @@ class PythonStreamingSourceRunner(
     PythonWorkerUtils.writeUTF(startOffset.json, dataOut)
     limit match {
       case _: ReadAllAvailable =>
-        dataOut.writeInt(0)
-        dataOut.flush()
+        // NOTE: we need to use a constant here to match the Python side, since the engine
+        // can decide on its own to use ReadAllAvailable and the Python side version of the
+        // instance isn't available here.
+        PythonWorkerUtils.writeUTF(READ_ALL_AVAILABLE_JSON, dataOut)
+
+      case p: PythonStreamingSourceReadLimit =>
+        PythonWorkerUtils.writeUTF(p.json, dataOut)
 
       case _ =>
-        // FIXME: Add support for other ReadLimit types
-        // throw QueryExecutionErrors.unsupportedReadLimitTypeError(limit.getClass.getName)
         throw new UnsupportedOperationException("Unsupported ReadLimit type: " +
           s"${limit.getClass.getName}")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
index 330a0513d360..eb5f7475636f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingDataSourceSuite.scala
@@ -350,7 +350,8 @@ class PythonStreamingDataSourceSimpleSuite extends PythonDataSourceSuiteBase {
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource
-         |from pyspark.sql.datasource import SimpleDataSourceStreamReader, SupportsTriggerAvailableNow
+         |from pyspark.sql.datasource import SimpleDataSourceStreamReader
+         |from pyspark.sql.streaming.datasource import SupportsTriggerAvailableNow
          |
          |class SimpleDataStreamReader(SimpleDataSourceStreamReader, SupportsTriggerAvailableNow):
          |    def initialOffset(self):
@@ -673,6 +674,205 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
     q.awaitTermination()
   }
 
+  private val testAdmissionControlScript =
+    s"""
+       |from pyspark.sql.datasource import DataSource
+       |from pyspark.sql.datasource import (
+       |    DataSourceStreamReader,
+       |    InputPartition,
+       |)
+       |from pyspark.sql.streaming.datasource import (
+       |    ReadAllAvailable,
+       |    ReadLimit,
+       |    ReadMaxRows,
+       |    SupportsAdmissionControl,
+       |)
+       |
+       |class TestDataStreamReader(
+       |    DataSourceStreamReader,
+       |    SupportsAdmissionControl
+       |):
+       |    def initialOffset(self):
+       |        return {"partition-1": 0}
+       |    def getDefaultReadLimit(self):
+       |        return ReadMaxRows(2)
+       |    def latestOffset(self, start: dict, readLimit: ReadLimit):
+       |        start_idx = start["partition-1"]
+       |        if isinstance(readLimit, ReadAllAvailable):
+       |            end_offset = start_idx + 10
+       |        else:
+       |            assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got "
+       |                                                        + str(type(readLimit)))
+       |            end_offset = start_idx + readLimit.max_rows
+       |        return {"partition-1": end_offset}
+       |    def reportLatestOffset(self):
+       |        return {"partition-1": 1000000}
+       |    def partitions(self, start: dict, end: dict):
+       |        start_index = start["partition-1"]
+       |        end_index = end["partition-1"]
+       |        return [InputPartition(i) for i in range(start_index, end_index)]
+       |    def read(self, partition):
+       |        yield (partition.value,)
+       |
+       |class $dataSourceName(DataSource):
+       |    def schema(self) -> str:
+       |        return "id INT"
+       |    def streamReader(self, schema):
+       |        return TestDataStreamReader()
+       |""".stripMargin
+
+  private val testAvailableNowScript =
+    s"""
+       |from pyspark.sql.datasource import DataSource
+       |from pyspark.sql.datasource import (
+       |    DataSourceStreamReader,
+       |    InputPartition,
+       |)
+       |from pyspark.sql.streaming.datasource import (
+       |    ReadAllAvailable,
+       |    ReadLimit,
+       |    ReadMaxRows,
+       |    SupportsAdmissionControl,
+       |    SupportsTriggerAvailableNow
+       |)
+       |
+       |class TestDataStreamReader(
+       |    DataSourceStreamReader,
+       |    SupportsAdmissionControl,
+       |    SupportsTriggerAvailableNow
+       |):
+       |    def initialOffset(self):
+       |        return {"partition-1": 0}
+       |    def getDefaultReadLimit(self):
+       |        return ReadMaxRows(2)
+       |    def latestOffset(self, start: dict, readLimit: ReadLimit):
+       |        start_idx = start["partition-1"]
+       |        if isinstance(readLimit, ReadAllAvailable):
+       |            end_offset = start_idx + 5
+       |        else:
+       |            assert isinstance(readLimit, ReadMaxRows), ("Expected ReadMaxRows read limit but got "
+       |                                                        + str(type(readLimit)))
+       |            end_offset = start_idx + readLimit.max_rows
+       |        end_offset = min(end_offset, self.desired_end_offset)
+       |        return {"partition-1": end_offset}
+       |    def reportLatestOffset(self):
+       |        return {"partition-1": 1000000}
+       |    def prepareForTriggerAvailableNow(self):
+       |        self.desired_end_offset = 10
+       |    def partitions(self, start: dict, end: dict):
+       |        start_index = start["partition-1"]
+       |        end_index = end["partition-1"]
+       |        return [InputPartition(i) for i in range(start_index, end_index)]
+       |    def read(self, partition):
+       |        yield (partition.value,)
+       |
+       |class $dataSourceName(DataSource):
+       |    def schema(self) -> str:
+       |        return "id INT"
+       |    def streamReader(self, schema):
+       |        return TestDataStreamReader()
+       |""".stripMargin
+
+  test("DataSourceStreamReader with Admission Control, Trigger.Once") {
+    assume(shouldTestPandasUDFs)
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      val q = df.writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .format("json")
+        // Use Trigger.Once here intentionally to test reads with admission control.
+        .trigger(Trigger.Once())
+        .start(outputDir.getAbsolutePath)
+      q.awaitTermination(waitTimeout.toMillis)
+
+      assert(q.recentProgress.length === 1)
+      assert(q.lastProgress.numInputRows === 10)
+      assert(q.lastProgress.sources(0).numInputRows === 10)
+      assert(q.lastProgress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+
+      val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+      assert(rowCount === 10)
+      checkAnswer(
+        spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 until rowCount.toInt).map(Row(_))
+      )
+    }
+  }
+
+  test("DataSourceStreamReader with Admission Control, processing time trigger") {
+    assume(shouldTestPandasUDFs)
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAdmissionControlScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val df = spark.readStream.format(dataSourceName).load()
+
+      val stopSignal = new CountDownLatch(1)
+
+      val q = df.writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .foreachBatch((df: DataFrame, batchId: Long) => {
+          df.cache()
+          checkAnswer(df, Seq(Row(batchId * 2), Row(batchId * 2 + 1)))
+          if (batchId == 10) stopSignal.countDown()
+        })
+        .trigger(Trigger.ProcessingTime(0))
+        .start()
+      stopSignal.await()
+      q.stop()
+      q.awaitTermination()
+
+      assert(q.recentProgress.length >= 10)
+      q.recentProgress.foreach { progress =>
+        assert(progress.numInputRows === 2)
+        assert(progress.sources(0).numInputRows === 2)
+        assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+      }
+    }
+  }
+
+  test("DataSourceStreamReader with Trigger.AvailableNow") {
+    assume(shouldTestPandasUDFs)
+    val dataSource = createUserDefinedPythonDataSource(dataSourceName, testAvailableNowScript)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    withTempDir { dir =>
+      val path = dir.getAbsolutePath
+      val checkpointDir = new File(path, "checkpoint")
+      val outputDir = new File(path, "output")
+      val df = spark.readStream.format(dataSourceName).load()
+      val q = df.writeStream
+        .option("checkpointLocation", checkpointDir.getAbsolutePath)
+        .format("json")
+        .trigger(Trigger.AvailableNow())
+        .start(outputDir.getAbsolutePath)
+      q.awaitTermination(waitTimeout.toMillis)
+
+      // 2 rows * 5 batches = 10 rows
+      assert(q.recentProgress.length === 5)
+      q.recentProgress.foreach { progress =>
+        assert(progress.numInputRows === 2)
+        assert(progress.sources(0).numInputRows === 2)
+        assert(progress.sources(0).latestOffset === """{"partition-1": 1000000}""")
+      }
+
+      val rowCount = spark.read.format("json").load(outputDir.getAbsolutePath).count()
+      assert(rowCount === 10)
+      checkAnswer(
+        spark.read.format("json").load(outputDir.getAbsolutePath),
+        (0 until rowCount.toInt).map(Row(_))
+      )
+    }
+  }
+
   test("Error creating stream reader") {
     assume(shouldTestPandasUDFs)
     val dataSourceScript =
@@ -770,7 +970,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
         func: PythonMicroBatchStream => Unit): Unit = {
       val options = CaseInsensitiveStringMap.empty()
       val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
-        pythonDs, dataSourceName, inputSchema, options)
+        pythonDs, errorDataSourceName, inputSchema, options)
       runner.init()
 
       val stream = new PythonMicroBatchStream(
@@ -837,7 +1037,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase {
         func: PythonMicroBatchStream => Unit): Unit = {
       val options = CaseInsensitiveStringMap.empty()
       val runner = PythonMicroBatchStream.createPythonStreamingSourceRunner(
-        pythonDs, dataSourceName, inputSchema, options)
+        pythonDs, errorDataSourceName, inputSchema, options)
       runner.init()
 
       val stream = new PythonMicroBatchStream(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

