This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new c2d69d2c84 [python] Push limit down to the reader layer for PK merge-on-read (#7808)
c2d69d2c84 is described below

commit c2d69d2c842c7945f174cf5ce3ea6299ad40495d
Author: chaoyang <[email protected]>
AuthorDate: Thu May 14 22:49:11 2026 +0800

    [python] Push limit down to the reader layer for PK merge-on-read (#7808)
    
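    PR #7742 fixed ``with_limit`` at the **scan** layer: ``TableScan`` /
    ``FileScanner`` now drop the splits that are no longer needed once the
    retained splits' row counts already cover the budget (the first split
    that covers the limit is still kept whole).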
    The **reader** layer, however, still drained every retained split to
    completion, and ``to_arrow`` returned those rows untrimmed: an
    append-only ``with_limit(1)`` returned the whole first split (4 rows
    in the existing test), and on PK merge-on-read ``with_limit(5)`` would
    merge every row of the retained splits and return all of them (125
    rows in the existing test). The IO and CPU cost therefore scaled with
    the size of the retained splits rather than with the limit value.
    
    The same query now returns at most N rows (exactly N whenever at least
    N rows are available). The merge pipeline gains a
    ``LimitedRecordReader`` wrapper at its outermost stage, and
    ``TableRead`` tracks a counter across splits so it stops opening
    further splits once the budget is met. On the Ray path each worker
    applies the limit to its own splits independently, so the resulting
    dataset is additionally capped with ``ds.limit(N)`` to keep the
    workers from collectively overshooting.
---
 paimon-python/pypaimon/ray/ray_paimon.py           |   7 +-
 .../pypaimon/read/datasource/ray_datasource.py     |   6 +-
 .../pypaimon/read/datasource/split_provider.py     |  19 +-
 paimon-python/pypaimon/read/read_builder.py        |   1 +
 .../pypaimon/read/reader/limited_record_reader.py  |  70 +++++++
 paimon-python/pypaimon/read/split_read.py          |   8 +-
 paimon-python/pypaimon/read/table_read.py          |  58 +++++-
 .../pypaimon/tests/py36/rest_ao_read_write_test.py |   7 +-
 .../pypaimon/tests/reader_append_only_test.py      |   7 +-
 .../pypaimon/tests/reader_primary_key_test.py      |   7 +-
 .../pypaimon/tests/rest/rest_read_write_test.py    |   7 +-
 .../pypaimon/tests/test_limit_pushdown.py          | 212 +++++++++++++++++++++
 .../pypaimon/tests/test_limited_record_reader.py   | 141 ++++++++++++++
 13 files changed, 524 insertions(+), 26 deletions(-)

diff --git a/paimon-python/pypaimon/ray/ray_paimon.py 
b/paimon-python/pypaimon/ray/ray_paimon.py
index 25c737c51e..bd81949394 100644
--- a/paimon-python/pypaimon/ray/ray_paimon.py
+++ b/paimon-python/pypaimon/ray/ray_paimon.py
@@ -92,13 +92,18 @@ def read_paimon(
             tag_name=tag_name,
         )
     )
-    return ray.data.read_datasource(
+    ds = ray.data.read_datasource(
         datasource,
         ray_remote_args=ray_remote_args,
         concurrency=concurrency,
         override_num_blocks=override_num_blocks,
         **read_args,
     )
+    # Per-task limit short-circuits each worker's reader, but N workers
+    # could collectively overshoot the user-visible limit. Cap on top.
+    if limit is not None:
+        ds = ds.limit(limit)
+    return ds
 
 
 def write_paimon(
diff --git a/paimon-python/pypaimon/read/datasource/ray_datasource.py 
b/paimon-python/pypaimon/read/datasource/ray_datasource.py
index 0e8a0836dc..25c6109259 100644
--- a/paimon-python/pypaimon/read/datasource/ray_datasource.py
+++ b/paimon-python/pypaimon/read/datasource/ray_datasource.py
@@ -125,6 +125,7 @@ class RayDatasource(Datasource):
         predicate = self._split_provider.predicate()
         read_type = self._split_provider.read_type()
         splits = self._split_provider.splits()
+        limit = self._split_provider.limit()
         if not splits:
             return []
 
@@ -146,10 +147,12 @@ class RayDatasource(Datasource):
                 predicate=predicate,
                 read_type=read_type,
                 schema=schema,
+                limit=limit,
         ) -> Iterable[pyarrow.Table]:
             """Read function that will be executed by Ray workers."""
             from pypaimon.read.table_read import TableRead
-            worker_table_read = TableRead(table, predicate, read_type)
+            worker_table_read = TableRead(
+                table, predicate, read_type, limit=limit)
 
             batch_reader = worker_table_read.to_arrow_batch_reader(splits)
             has_data = False
@@ -175,6 +178,7 @@ class RayDatasource(Datasource):
             predicate=predicate,
             read_type=read_type,
             schema=schema,
+            limit=limit,
         )
 
         read_tasks = []
diff --git a/paimon-python/pypaimon/read/datasource/split_provider.py 
b/paimon-python/pypaimon/read/datasource/split_provider.py
index c63b792a4a..e981eebb96 100644
--- a/paimon-python/pypaimon/read/datasource/split_provider.py
+++ b/paimon-python/pypaimon/read/datasource/split_provider.py
@@ -58,6 +58,15 @@ class SplitProvider(ABC):
         to peek at concrete provider types to format its name.
         """
 
+    def limit(self) -> Optional[int]:
+        """Optional row limit applied at scan/read time.
+
+        Subclasses override when the limit is known up front so the
+        datasource can thread it through to per-task ``TableRead``
+        instances and stop reading once the budget is met.
+        """
+        return None
+
 
 class CatalogSplitProvider(SplitProvider):
     """Plan splits from a fully-qualified table identifier and catalog options.
@@ -143,6 +152,9 @@ class CatalogSplitProvider(SplitProvider):
     def predicate(self):
         return self._predicate
 
+    def limit(self) -> Optional[int]:
+        return self._limit
+
     def display_name(self) -> str:
         return self._table_identifier
 
@@ -155,11 +167,13 @@ class PreResolvedSplitProvider(SplitProvider):
     skipped.
     """
 
-    def __init__(self, table, splits: List[Split], read_type, predicate=None):
+    def __init__(self, table, splits: List[Split], read_type, predicate=None,
+                 limit: Optional[int] = None):
         self._table = table
         self._splits = splits
         self._read_type = read_type
         self._predicate = predicate
+        self._limit = limit
 
     def table(self):
         return self._table
@@ -173,6 +187,9 @@ class PreResolvedSplitProvider(SplitProvider):
     def predicate(self):
         return self._predicate
 
+    def limit(self) -> Optional[int]:
+        return self._limit
+
     def display_name(self) -> str:
         identifier = self._table.identifier
         if hasattr(identifier, 'get_full_name'):
diff --git a/paimon-python/pypaimon/read/read_builder.py 
b/paimon-python/pypaimon/read/read_builder.py
index 3bb66b7602..13a951df0b 100644
--- a/paimon-python/pypaimon/read/read_builder.py
+++ b/paimon-python/pypaimon/read/read_builder.py
@@ -79,6 +79,7 @@ class ReadBuilder:
             predicate=self._predicate,
             read_type=self.read_type(),
             nested_name_paths=self._nested_name_paths(),
+            limit=self._limit,
         )
 
     def _nested_name_paths(self) -> Optional[List[List[str]]]:
diff --git a/paimon-python/pypaimon/read/reader/limited_record_reader.py 
b/paimon-python/pypaimon/read/reader/limited_record_reader.py
new file mode 100644
index 0000000000..74f2612ebd
--- /dev/null
+++ b/paimon-python/pypaimon/read/reader/limited_record_reader.py
@@ -0,0 +1,70 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Row-level limit wrapper for any ``RecordReader`` chain.
+
+Currently used at the outermost stage of the PK merge-on-read pipeline
+so the merge output is short-circuited at the row level instead of
+running to completion, but the wrapper itself is generic and may be
+reused on other reader chains.
+"""
+
+from typing import Optional
+
+from pypaimon.read.reader.iface.record_iterator import RecordIterator
+from pypaimon.read.reader.iface.record_reader import RecordReader
+
+
+class LimitedRecordReader(RecordReader):
+    """Stop emitting rows once ``limit`` rows have been delivered."""
+
+    def __init__(self, inner: RecordReader, limit: int):
+        if limit < 0:
+            raise ValueError("limit must be non-negative, got %d" % limit)
+        self._inner = inner
+        self._limit = limit
+        # Public so the iterator can read/write the shared counter without
+        # going through accessor calls per row.
+        self.count = 0
+
+    def read_batch(self) -> Optional[RecordIterator]:
+        if self.count >= self._limit:
+            return None
+        batch = self._inner.read_batch()
+        if batch is None:
+            return None
+        return _LimitedRecordIterator(batch, self)
+
+    def close(self) -> None:
+        self._inner.close()
+
+
+class _LimitedRecordIterator(RecordIterator):
+
+    def __init__(self, inner: RecordIterator, limiter: LimitedRecordReader):
+        self._inner = inner
+        self._limiter = limiter
+
+    def next(self):
+        if self._limiter.count >= self._limiter._limit:
+            return None
+        row = self._inner.next()
+        if row is None:
+            return None
+        self._limiter.count += 1
+        return row
diff --git a/paimon-python/pypaimon/read/split_read.py 
b/paimon-python/pypaimon/read/split_read.py
index ed3689b14b..8a203c9f4c 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -577,7 +577,8 @@ class MergeFileSplitRead(SplitRead):
             read_type: List[DataField],
             split: Split,
             row_tracking_enabled: bool,
-            outer_extract_name_paths: Optional[List[List[str]]] = None):
+            outer_extract_name_paths: Optional[List[List[str]]] = None,
+            limit: Optional[int] = None):
         # Merge functions need full ROW sub-structures, so nested paths
         # are not pushed down here; sub-path extraction happens above
         # the merge via OuterProjectionRecordReader.
@@ -590,6 +591,7 @@ class MergeFileSplitRead(SplitRead):
             nested_name_paths=None,
         )
         self.outer_extract_name_paths = outer_extract_name_paths
+        self.limit = limit
 
     def kv_reader_supplier(self, file: DataFileMeta, dv_factory: 
Optional[Callable] = None) -> RecordReader:
         file_batch_reader = self.file_reader_supplier(file, True, 
self._get_final_read_data_fields(), False)
@@ -631,6 +633,10 @@ class MergeFileSplitRead(SplitRead):
             inner_top_names = [f.name for f in 
self.read_fields[-self.value_arity:]]
             reader = OuterProjectionRecordReader(
                 reader, inner_top_names, self.outer_extract_name_paths)
+        if self.limit is not None:
+            from pypaimon.read.reader.limited_record_reader import \
+                LimitedRecordReader
+            reader = LimitedRecordReader(reader, self.limit)
         return reader
 
     def _get_all_data_fields(self):
diff --git a/paimon-python/pypaimon/read/table_read.py 
b/paimon-python/pypaimon/read/table_read.py
index 1f22d4354f..c45f2a8532 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -42,6 +42,7 @@ class TableRead:
         read_type: List[DataField],
         include_row_kind: bool = False,
         nested_name_paths: Optional[List[List[str]]] = None,
+        limit: Optional[int] = None,
     ):
         from pypaimon.table.file_store_table import FileStoreTable
 
@@ -50,14 +51,24 @@ class TableRead:
         self.read_type = read_type
         self.include_row_kind = include_row_kind
         self.nested_name_paths = nested_name_paths
+        self.limit = limit
 
     def to_iterator(self, splits: List[Split]) -> Iterator:
+        limit = self.limit
+
         def _record_generator():
+            count = 0
             for split in splits:
+                if limit is not None and count >= limit:
+                    return
                 reader = self._create_split_read(split).create_reader()
                 try:
                     for batch in iter(reader.read_batch, None):
-                        yield from iter(batch.next, None)
+                        for row in iter(batch.next, None):
+                            yield row
+                            count += 1
+                            if limit is not None and count >= limit:
+                                return
                 finally:
                     reader.close()
 
@@ -113,21 +124,34 @@ class TableRead:
 
     def _arrow_batch_generator(self, splits: List[Split], schema: 
pyarrow.Schema) -> Iterator[pyarrow.RecordBatch]:
         chunk_size = 65536
+        # ``remaining`` tracks how many rows we are still allowed to emit
+        # across all splits. ``None`` means unlimited.
+        remaining = self.limit
 
         for split in splits:
+            if remaining is not None and remaining <= 0:
+                break
             reader = self._create_split_read(split).create_reader()
             try:
                 if isinstance(reader, RecordBatchReader):
-                    # Add row kind column if requested (default to +I for 
RecordBatchReader)
-                    if self.include_row_kind:
-                        for batch in iter(reader.read_arrow_batch, None):
-                            yield self._add_row_kind_column_to_batch(batch, 
"+I")
-                    else:
-                        yield from iter(reader.read_arrow_batch, None)
+                    for batch in iter(reader.read_arrow_batch, None):
+                        if remaining is not None and batch.num_rows > 
remaining:
+                            batch = batch.slice(0, remaining)
+                        if self.include_row_kind:
+                            batch = self._add_row_kind_column_to_batch(batch, 
"+I")
+                        yield batch
+                        if remaining is not None:
+                            remaining -= batch.num_rows
+                            if remaining <= 0:
+                                break
                 else:
                     row_tuple_chunk = []
                     row_kind_chunk = []
-                    for row_iterator in iter(reader.read_batch, None):
+                    while True:
+                        row_iterator = reader.read_batch()
+                        if row_iterator is None:
+                            break
+                        stop = False
                         for row in iter(row_iterator.next, None):
                             if not isinstance(row, OffsetRow):
                                 raise TypeError(f"Expected OffsetRow, but got 
{type(row).__name__}")
@@ -135,6 +159,12 @@ class TableRead:
                             if self.include_row_kind:
                                 
row_kind_chunk.append(row.get_row_kind().to_string())
 
+                            if remaining is not None:
+                                remaining -= 1
+                                if remaining <= 0:
+                                    stop = True
+                                    break
+
                             if len(row_tuple_chunk) >= chunk_size:
                                 batch = 
self._convert_rows_to_arrow_batch_with_row_kind(
                                     row_tuple_chunk, row_kind_chunk, schema
@@ -142,6 +172,8 @@ class TableRead:
                                 yield batch
                                 row_tuple_chunk = []
                                 row_kind_chunk = []
+                        if stop:
+                            break
 
                     if row_tuple_chunk:
                         batch = 
self._convert_rows_to_arrow_batch_with_row_kind(
@@ -242,15 +274,22 @@ class TableRead:
                 splits=splits,
                 read_type=self.read_type,
                 predicate=self.predicate,
+                limit=self.limit,
             )
         )
-        return ray.data.read_datasource(
+        ds = ray.data.read_datasource(
             datasource,
             ray_remote_args=ray_remote_args,
             concurrency=concurrency,
             override_num_blocks=override_num_blocks,
             **read_args
         )
+        # Each Ray worker applies the per-task limit independently, so N
+        # workers can collectively yield up to N * limit rows. Cap the
+        # final dataset to the user-visible limit on top.
+        if self.limit is not None:
+            ds = ds.limit(self.limit)
+        return ds
 
     def to_torch(
         self,
@@ -285,6 +324,7 @@ class TableRead:
                 split=split,
                 row_tracking_enabled=False,
                 outer_extract_name_paths=outer_extract_name_paths,
+                limit=self.limit,
             )
         elif self.table.options.data_evolution_enabled():
             if self.nested_name_paths and any(
diff --git a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py 
b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
index fcbcf558ae..a0a3ad37e8 100644
--- a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
+++ b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
@@ -613,11 +613,12 @@ class RESTAOReadWritePy36Test(RESTBaseTest):
         table = self.rest_catalog.get_table('default.test_append_only_limit')
         self._write_test_table(table)
 
+        # Row-level limit: the reader stops at exactly N rows (not "first
+        # split's full row count"). Scan still keeps the first split that
+        # covers the limit; the reader short-circuits inside it.
         read_builder = table.new_read_builder().with_limit(1)
         actual = self._read_test_table(read_builder)
-        # only records from 1st commit (1st split) will be read
-        # might be split of "dt=1" or split of "dt=2"
-        self.assertEqual(actual.num_rows, 4)
+        self.assertEqual(actual.num_rows, 1)
 
     def test_write_wrong_schema(self):
         self.rest_catalog.create_table('default.test_wrong_schema',
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py 
b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 8006d122d9..5cb8a2dfca 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -658,11 +658,12 @@ class AoReaderTest(unittest.TestCase):
         table = self.catalog.get_table('default.test_append_only_limit')
         self._write_test_table(table)
 
+        # Row-level limit: the reader stops at exactly N rows (not "first
+        # split's full row count"). Scan still keeps the first split that
+        # covers the limit; the reader short-circuits inside it.
         read_builder = table.new_read_builder().with_limit(1)
         actual = self._read_test_table(read_builder)
-        # only records from 1st commit (1st split) will be read
-        # might be split of "dt=1" or split of "dt=2"
-        self.assertEqual(actual.num_rows, 4)
+        self.assertEqual(actual.num_rows, 1)
 
     def test_incremental_timestamp(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema, 
partition_keys=['dt'])
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py 
b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index adc8c9cf9b..89a9e8a6a1 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -350,7 +350,6 @@ class PkReaderTest(unittest.TestCase):
             len(merge_splits), 0,
             "Should have at least one merge split to test limit with merge 
scenario")
 
-        total_unique_rows = 125
         for limit in [5, 10, 20, 50]:
             read_builder = table.new_read_builder().with_limit(limit)
             table_read = read_builder.new_read()
@@ -362,9 +361,9 @@ class PkReaderTest(unittest.TestCase):
             result = table_read.to_arrow(splits)
             row_count = result.num_rows if result is not None else 0
             self.assertEqual(
-                row_count, total_unique_rows,
-                f"with_limit({limit}) should return all rows for PK table "
-                f"(read-level limit not yet implemented)")
+                row_count, limit,
+                f"with_limit({limit}) on PK table must return exactly "
+                f"{limit} rows now that the read-level limit is wired")
 
     def test_incremental_timestamp(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema,
diff --git a/paimon-python/pypaimon/tests/rest/rest_read_write_test.py 
b/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
index bd357f907a..f46ce5444b 100644
--- a/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
@@ -302,11 +302,12 @@ class RESTTableReadWriteTest(RESTBaseTest):
         table = self.rest_catalog.get_table('default.test_append_only_limit')
         self._write_test_table(table)
 
+        # Row-level limit: the reader stops at exactly N rows (not "first
+        # split's full row count"). Scan still keeps the first split that
+        # covers the limit; the reader short-circuits inside it.
         read_builder = table.new_read_builder().with_limit(1)
         actual = self._read_test_table(read_builder)
-        # only records from 1st commit (1st split) will be read
-        # might be split of "dt=1" or split of "dt=2"
-        self.assertEqual(actual.num_rows, 4)
+        self.assertEqual(actual.num_rows, 1)
 
     def test_pk_parquet_reader(self):
         schema = Schema.from_pyarrow_schema(self.pa_schema,
diff --git a/paimon-python/pypaimon/tests/test_limit_pushdown.py 
b/paimon-python/pypaimon/tests/test_limit_pushdown.py
new file mode 100644
index 0000000000..2e717c28c7
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_limit_pushdown.py
@@ -0,0 +1,212 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""End-to-end coverage for ``with_limit`` after row-level pushdown.
+
+Locks the contract: ``with_limit(N)`` returns at most ``N`` rows, and
+the reader actually stops at that boundary instead of reading every
+split / merge output to completion and trimming at the consumer.
+"""
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+
+class LimitPushdownTest(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog_options = {'warehouse': cls.warehouse}
+        cls.catalog = CatalogFactory.create(cls.catalog_options)
+        cls.catalog.create_database('default', False)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    @staticmethod
+    def _ao_schema() -> pa.Schema:
+        return pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            ('val', pa.int64()),
+            pa.field('dt', pa.string(), nullable=False),
+        ])
+
+    @staticmethod
+    def _pk_schema() -> pa.Schema:
+        return pa.schema([
+            pa.field('id', pa.int64(), nullable=False),
+            ('val', pa.int64()),
+        ])
+
+    def _create_ao_table(self, name: str):
+        identifier = 'default.' + name
+        schema = Schema.from_pyarrow_schema(
+            self._ao_schema(),
+            partition_keys=['dt'],
+            options={'file.format': 'parquet'},
+        )
+        self.catalog.create_table(identifier, schema, False)
+        return self.catalog.get_table(identifier)
+
+    def _create_pk_table(self, name: str, *, num_buckets: int = 1):
+        identifier = 'default.' + name
+        schema = Schema.from_pyarrow_schema(
+            self._pk_schema(),
+            primary_keys=['id'],
+            options={'bucket': str(num_buckets), 'file.format': 'parquet'},
+        )
+        self.catalog.create_table(identifier, schema, False)
+        return self.catalog.get_table(identifier)
+
+    def _write_ao_partitions(self, table, partitions):
+        for dt, rows in partitions:
+            wb = table.new_batch_write_builder()
+            w = wb.new_write()
+            data = pa.Table.from_pylist(
+                [{'id': r, 'val': r * 10, 'dt': dt} for r in rows],
+                schema=self._ao_schema())
+            w.write_arrow(data)
+            wb.new_commit().commit(w.prepare_commit())
+            w.close()
+
+    def _write_pk_snapshots(self, table, snapshots):
+        for rows in snapshots:
+            wb = table.new_batch_write_builder()
+            w = wb.new_write()
+            data = pa.Table.from_pylist(
+                [{'id': i, 'val': v} for i, v in rows], 
schema=self._pk_schema())
+            w.write_arrow(data)
+            wb.new_commit().commit(w.prepare_commit())
+            w.close()
+
+    # ---- append-only -----------------------------------------------------
+
+    def test_append_only_limit_stops_within_first_split(self):
+        """With limit=3 on a partitioned append-only table, the result is
+        exactly 3 rows — even though each partition split has 5 rows."""
+        table = self._create_ao_table('limit_ao_within_split')
+        self._write_ao_partitions(table, [
+            ('p1', list(range(5))),       # 5 rows
+            ('p2', list(range(5, 10))),   # 5 rows
+        ])
+        rb = table.new_read_builder().with_limit(3)
+        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+        self.assertEqual(result.num_rows, 3)
+
+    def test_append_only_limit_spans_multiple_splits(self):
+        """Limit larger than first split: read carries over to the next
+        split until the budget is met."""
+        table = self._create_ao_table('limit_ao_span_splits')
+        self._write_ao_partitions(table, [
+            ('p1', [1, 2]),
+            ('p2', [3, 4]),
+            ('p3', [5, 6]),
+        ])
+        rb = table.new_read_builder().with_limit(5)
+        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+        self.assertEqual(result.num_rows, 5)
+
+    def test_append_only_limit_zero_returns_empty(self):
+        table = self._create_ao_table('limit_ao_zero')
+        self._write_ao_partitions(table, [('p1', [1, 2, 3])])
+        rb = table.new_read_builder().with_limit(0)
+        splits = rb.new_scan().plan().splits()
+        result = rb.new_read().to_arrow(splits)
+        self.assertEqual(result.num_rows, 0)
+
+    def test_append_only_limit_larger_than_total(self):
+        """Limit greater than the total returns the total, not the limit."""
+        table = self._create_ao_table('limit_ao_oversize')
+        self._write_ao_partitions(table, [('p1', [1, 2, 3])])
+        rb = table.new_read_builder().with_limit(100)
+        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+        self.assertEqual(result.num_rows, 3)
+
+    # ---- PK merge-on-read ------------------------------------------------
+
+    def test_pk_merge_limit_stops_within_first_split(self):
+        """PK + multiple snapshots forces the merge-read path. The reader
+        must stop at limit rows instead of running every section to
+        completion and trimming at the consumer."""
+        table = self._create_pk_table('limit_pk_within_split')
+        # Two snapshots over the same key range → merge path; total
+        # post-merge unique rows = 20.
+        self._write_pk_snapshots(table, [
+            [(i, i) for i in range(20)],
+            [(i, i + 1000) for i in range(0, 20, 2)],
+        ])
+        for limit in (1, 5, 10, 19):
+            rb = table.new_read_builder().with_limit(limit)
+            result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+            self.assertEqual(
+                result.num_rows, limit,
+                "with_limit(%d) must short-circuit at the row level" % limit)
+
+    def test_pk_merge_limit_equals_total(self):
+        """Limit equal to total post-merge row count: returns everything."""
+        table = self._create_pk_table('limit_pk_equals_total')
+        self._write_pk_snapshots(table, [
+            [(i, i) for i in range(10)],
+            [(i, i + 100) for i in range(5)],
+        ])
+        rb = table.new_read_builder().with_limit(10)
+        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+        self.assertEqual(result.num_rows, 10)
+
+    def test_pk_merge_limit_with_predicate(self):
+        """``with_limit`` plus ``with_filter``: the filter prunes first and
+        the limit caps what survives. ``val >= 1000`` matches the latest
+        write of the even ``id`` rows; limit then takes the prefix."""
+        table = self._create_pk_table('limit_pk_with_filter')
+        self._write_pk_snapshots(table, [
+            [(i, i) for i in range(20)],
+            [(i, i + 1000) for i in range(0, 20, 2)],  # update evens
+        ])
+        rb = table.new_read_builder()
+        pred = rb.new_predicate_builder().greater_or_equal('val', 1000)
+        rb = rb.with_filter(pred).with_limit(3)
+        result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+        self.assertEqual(result.num_rows, 3)
+        for v in result.column('val').to_pylist():
+            self.assertGreaterEqual(v, 1000)
+
+    # ---- to_iterator path ------------------------------------------------
+
+    def test_to_iterator_limit_short_circuits(self):
+        table = self._create_ao_table('limit_iter')
+        self._write_ao_partitions(table, [
+            ('p1', list(range(50))),
+            ('p2', list(range(50, 100))),
+        ])
+        rb = table.new_read_builder().with_limit(7)
+        it = rb.new_read().to_iterator(rb.new_scan().plan().splits())
+        rows = list(it)
+        self.assertEqual(len(rows), 7)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paimon-python/pypaimon/tests/test_limited_record_reader.py 
b/paimon-python/pypaimon/tests/test_limited_record_reader.py
new file mode 100644
index 0000000000..edbc1b75f4
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_limited_record_reader.py
@@ -0,0 +1,141 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import unittest
+from typing import List, Optional
+
+from pypaimon.read.reader.iface.record_iterator import RecordIterator
+from pypaimon.read.reader.iface.record_reader import RecordReader
+from pypaimon.read.reader.limited_record_reader import LimitedRecordReader
+
+
+class _ListIterator(RecordIterator):
+    def __init__(self, items: List):
+        self._items = items
+        self._idx = 0
+
+    def next(self):
+        if self._idx >= len(self._items):
+            return None
+        v = self._items[self._idx]
+        self._idx += 1
+        return v
+
+
+class _StaticReader(RecordReader):
+    """Hands back batches one at a time, tracks close calls and the
+    number of times ``read_batch`` was invoked (so tests can prove the
+    limiter actually short-circuited instead of draining the inner)."""
+
+    def __init__(self, batches: List[List]):
+        self._batches = batches
+        self._idx = 0
+        self.closed = False
+        self.read_batch_calls = 0
+
+    def read_batch(self) -> Optional[RecordIterator]:
+        self.read_batch_calls += 1
+        if self._idx >= len(self._batches):
+            return None
+        batch = self._batches[self._idx]
+        self._idx += 1
+        return _ListIterator(batch)
+
+    def close(self):
+        self.closed = True
+
+
+def _drain(reader: RecordReader) -> List:
+    out = []
+    while True:
+        batch = reader.read_batch()
+        if batch is None:
+            break
+        while True:
+            v = batch.next()
+            if v is None:
+                break
+            out.append(v)
+    return out
+
+
+class LimitedRecordReaderTest(unittest.TestCase):
+
+    def test_limit_within_first_batch(self):
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2, 3, 4, 5]]), limit=3)
+        self.assertEqual(_drain(reader), [1, 2, 3])
+
+    def test_limit_spans_multiple_batches(self):
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2], [3, 4], [5, 6]]), limit=5)
+        self.assertEqual(_drain(reader), [1, 2, 3, 4, 5])
+
+    def test_limit_larger_than_total_returns_everything(self):
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2, 3]]), limit=999)
+        self.assertEqual(_drain(reader), [1, 2, 3])
+
+    def test_limit_zero_returns_nothing(self):
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2, 3]]), limit=0)
+        self.assertEqual(_drain(reader), [])
+        # read_batch should short-circuit immediately rather than peek.
+        self.assertIsNone(reader.read_batch())
+
+    def test_negative_limit_rejected(self):
+        with self.assertRaises(ValueError):
+            LimitedRecordReader(_StaticReader([]), limit=-1)
+
+    def test_close_propagates(self):
+        inner = _StaticReader([[1, 2]])
+        reader = LimitedRecordReader(inner, limit=10)
+        reader.close()
+        self.assertTrue(inner.closed)
+
+    def test_iterator_stops_mid_batch(self):
+        # Limit cuts halfway through a batch; the next() call past the limit
+        # must return None even though the inner batch still has items.
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2, 3, 4, 5]]), limit=2)
+        batch = reader.read_batch()
+        self.assertEqual(batch.next(), 1)
+        self.assertEqual(batch.next(), 2)
+        self.assertIsNone(batch.next())
+        # Subsequent read_batch is also None.
+        self.assertIsNone(reader.read_batch())
+
+    def test_count_visible_for_observability(self):
+        reader = LimitedRecordReader(
+            _StaticReader([[1, 2, 3, 4]]), limit=10)
+        _drain(reader)
+        self.assertEqual(reader.count, 4)
+
+    def test_does_not_drain_inner_when_limit_met_within_first_batch(self):
+        """Direct proof of the short-circuit: once the limiter has handed
+        out ``limit`` rows the next ``read_batch`` short-circuits at the
+        entry guard and never pulls a second batch from the inner."""
+        inner = _StaticReader([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
+        reader = LimitedRecordReader(inner, limit=3)
+        self.assertEqual(_drain(reader), [1, 2, 3])
+        # Only the first batch was fetched; the second is never asked for.
+        self.assertEqual(inner.read_batch_calls, 1)
+
+
+if __name__ == '__main__':
+    unittest.main()  # run this module's TestCase classes when executed


Reply via email to