This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new c2d69d2c84 [python] Push limit down to the reader layer for PK
merge-on-read (#7808)
c2d69d2c84 is described below
commit c2d69d2c842c7945f174cf5ce3ea6299ad40495d
Author: chaoyang <[email protected]>
AuthorDate: Thu May 14 22:49:11 2026 +0800
[python] Push limit down to the reader layer for PK merge-on-read (#7808)
PR #7742 fixed ``with_limit`` at the **scan** layer: ``TableScan`` /
``FileScanner`` now drop the splits that come after the point where the
accumulated row count already covers the budget.
The **reader** layer, however, still drained every retained split to
completion before the consumer trimmed the result. On PK
merge-on-read in particular, ``with_limit(5)`` would happily merge
hundreds or thousands of rows per split and discard all but the first
five at ``to_arrow`` — the IO and CPU cost was effectively not bounded
by the limit value.
The same query now stops at exactly N rows. The merge pipeline gains a
``LimitedRecordReader`` wrapper at its outermost stage, and
``TableRead`` tracks a counter across splits so it stops opening
further splits once the budget is met. The Ray path is capped on top
with ``ds.limit(N)`` so independent workers can't collectively
overshoot.
---
paimon-python/pypaimon/ray/ray_paimon.py | 7 +-
.../pypaimon/read/datasource/ray_datasource.py | 6 +-
.../pypaimon/read/datasource/split_provider.py | 19 +-
paimon-python/pypaimon/read/read_builder.py | 1 +
.../pypaimon/read/reader/limited_record_reader.py | 70 +++++++
paimon-python/pypaimon/read/split_read.py | 8 +-
paimon-python/pypaimon/read/table_read.py | 58 +++++-
.../pypaimon/tests/py36/rest_ao_read_write_test.py | 7 +-
.../pypaimon/tests/reader_append_only_test.py | 7 +-
.../pypaimon/tests/reader_primary_key_test.py | 7 +-
.../pypaimon/tests/rest/rest_read_write_test.py | 7 +-
.../pypaimon/tests/test_limit_pushdown.py | 212 +++++++++++++++++++++
.../pypaimon/tests/test_limited_record_reader.py | 141 ++++++++++++++
13 files changed, 524 insertions(+), 26 deletions(-)
diff --git a/paimon-python/pypaimon/ray/ray_paimon.py
b/paimon-python/pypaimon/ray/ray_paimon.py
index 25c737c51e..bd81949394 100644
--- a/paimon-python/pypaimon/ray/ray_paimon.py
+++ b/paimon-python/pypaimon/ray/ray_paimon.py
@@ -92,13 +92,18 @@ def read_paimon(
tag_name=tag_name,
)
)
- return ray.data.read_datasource(
+ ds = ray.data.read_datasource(
datasource,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
override_num_blocks=override_num_blocks,
**read_args,
)
+ # Per-task limit short-circuits each worker's reader, but N workers
+ # could collectively overshoot the user-visible limit. Cap on top.
+ if limit is not None:
+ ds = ds.limit(limit)
+ return ds
def write_paimon(
diff --git a/paimon-python/pypaimon/read/datasource/ray_datasource.py
b/paimon-python/pypaimon/read/datasource/ray_datasource.py
index 0e8a0836dc..25c6109259 100644
--- a/paimon-python/pypaimon/read/datasource/ray_datasource.py
+++ b/paimon-python/pypaimon/read/datasource/ray_datasource.py
@@ -125,6 +125,7 @@ class RayDatasource(Datasource):
predicate = self._split_provider.predicate()
read_type = self._split_provider.read_type()
splits = self._split_provider.splits()
+ limit = self._split_provider.limit()
if not splits:
return []
@@ -146,10 +147,12 @@ class RayDatasource(Datasource):
predicate=predicate,
read_type=read_type,
schema=schema,
+ limit=limit,
) -> Iterable[pyarrow.Table]:
"""Read function that will be executed by Ray workers."""
from pypaimon.read.table_read import TableRead
- worker_table_read = TableRead(table, predicate, read_type)
+ worker_table_read = TableRead(
+ table, predicate, read_type, limit=limit)
batch_reader = worker_table_read.to_arrow_batch_reader(splits)
has_data = False
@@ -175,6 +178,7 @@ class RayDatasource(Datasource):
predicate=predicate,
read_type=read_type,
schema=schema,
+ limit=limit,
)
read_tasks = []
diff --git a/paimon-python/pypaimon/read/datasource/split_provider.py
b/paimon-python/pypaimon/read/datasource/split_provider.py
index c63b792a4a..e981eebb96 100644
--- a/paimon-python/pypaimon/read/datasource/split_provider.py
+++ b/paimon-python/pypaimon/read/datasource/split_provider.py
@@ -58,6 +58,15 @@ class SplitProvider(ABC):
to peek at concrete provider types to format its name.
"""
+ def limit(self) -> Optional[int]:
+ """Optional row limit applied at scan/read time.
+
+ Subclasses override when the limit is known up front so the
+ datasource can thread it through to per-task ``TableRead``
+ instances and stop reading once the budget is met.
+ """
+ return None
+
class CatalogSplitProvider(SplitProvider):
"""Plan splits from a fully-qualified table identifier and catalog options.
@@ -143,6 +152,9 @@ class CatalogSplitProvider(SplitProvider):
def predicate(self):
return self._predicate
+ def limit(self) -> Optional[int]:
+ return self._limit
+
def display_name(self) -> str:
return self._table_identifier
@@ -155,11 +167,13 @@ class PreResolvedSplitProvider(SplitProvider):
skipped.
"""
- def __init__(self, table, splits: List[Split], read_type, predicate=None):
+ def __init__(self, table, splits: List[Split], read_type, predicate=None,
+ limit: Optional[int] = None):
self._table = table
self._splits = splits
self._read_type = read_type
self._predicate = predicate
+ self._limit = limit
def table(self):
return self._table
@@ -173,6 +187,9 @@ class PreResolvedSplitProvider(SplitProvider):
def predicate(self):
return self._predicate
+ def limit(self) -> Optional[int]:
+ return self._limit
+
def display_name(self) -> str:
identifier = self._table.identifier
if hasattr(identifier, 'get_full_name'):
diff --git a/paimon-python/pypaimon/read/read_builder.py
b/paimon-python/pypaimon/read/read_builder.py
index 3bb66b7602..13a951df0b 100644
--- a/paimon-python/pypaimon/read/read_builder.py
+++ b/paimon-python/pypaimon/read/read_builder.py
@@ -79,6 +79,7 @@ class ReadBuilder:
predicate=self._predicate,
read_type=self.read_type(),
nested_name_paths=self._nested_name_paths(),
+ limit=self._limit,
)
def _nested_name_paths(self) -> Optional[List[List[str]]]:
diff --git a/paimon-python/pypaimon/read/reader/limited_record_reader.py
b/paimon-python/pypaimon/read/reader/limited_record_reader.py
new file mode 100644
index 0000000000..74f2612ebd
--- /dev/null
+++ b/paimon-python/pypaimon/read/reader/limited_record_reader.py
@@ -0,0 +1,70 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Row-level limit wrapper for any ``RecordReader`` chain.
+
+Currently used at the outermost stage of the PK merge-on-read pipeline
+so the merge output is short-circuited at the row level instead of
+running to completion, but the wrapper itself is generic and may be
+reused on other reader chains.
+"""
+
+from typing import Optional
+
+from pypaimon.read.reader.iface.record_iterator import RecordIterator
+from pypaimon.read.reader.iface.record_reader import RecordReader
+
+
+class LimitedRecordReader(RecordReader):
+ """Stop emitting rows once ``limit`` rows have been delivered."""
+
+ def __init__(self, inner: RecordReader, limit: int):
+ if limit < 0:
+ raise ValueError("limit must be non-negative, got %d" % limit)
+ self._inner = inner
+ self._limit = limit
+ # Public so the iterator can read/write the shared counter without
+ # going through accessor calls per row.
+ self.count = 0
+
+ def read_batch(self) -> Optional[RecordIterator]:
+ if self.count >= self._limit:
+ return None
+ batch = self._inner.read_batch()
+ if batch is None:
+ return None
+ return _LimitedRecordIterator(batch, self)
+
+ def close(self) -> None:
+ self._inner.close()
+
+
+class _LimitedRecordIterator(RecordIterator):
+
+ def __init__(self, inner: RecordIterator, limiter: LimitedRecordReader):
+ self._inner = inner
+ self._limiter = limiter
+
+ def next(self):
+ if self._limiter.count >= self._limiter._limit:
+ return None
+ row = self._inner.next()
+ if row is None:
+ return None
+ self._limiter.count += 1
+ return row
diff --git a/paimon-python/pypaimon/read/split_read.py
b/paimon-python/pypaimon/read/split_read.py
index ed3689b14b..8a203c9f4c 100644
--- a/paimon-python/pypaimon/read/split_read.py
+++ b/paimon-python/pypaimon/read/split_read.py
@@ -577,7 +577,8 @@ class MergeFileSplitRead(SplitRead):
read_type: List[DataField],
split: Split,
row_tracking_enabled: bool,
- outer_extract_name_paths: Optional[List[List[str]]] = None):
+ outer_extract_name_paths: Optional[List[List[str]]] = None,
+ limit: Optional[int] = None):
# Merge functions need full ROW sub-structures, so nested paths
# are not pushed down here; sub-path extraction happens above
# the merge via OuterProjectionRecordReader.
@@ -590,6 +591,7 @@ class MergeFileSplitRead(SplitRead):
nested_name_paths=None,
)
self.outer_extract_name_paths = outer_extract_name_paths
+ self.limit = limit
def kv_reader_supplier(self, file: DataFileMeta, dv_factory:
Optional[Callable] = None) -> RecordReader:
file_batch_reader = self.file_reader_supplier(file, True,
self._get_final_read_data_fields(), False)
@@ -631,6 +633,10 @@ class MergeFileSplitRead(SplitRead):
inner_top_names = [f.name for f in
self.read_fields[-self.value_arity:]]
reader = OuterProjectionRecordReader(
reader, inner_top_names, self.outer_extract_name_paths)
+ if self.limit is not None:
+ from pypaimon.read.reader.limited_record_reader import \
+ LimitedRecordReader
+ reader = LimitedRecordReader(reader, self.limit)
return reader
def _get_all_data_fields(self):
diff --git a/paimon-python/pypaimon/read/table_read.py
b/paimon-python/pypaimon/read/table_read.py
index 1f22d4354f..c45f2a8532 100644
--- a/paimon-python/pypaimon/read/table_read.py
+++ b/paimon-python/pypaimon/read/table_read.py
@@ -42,6 +42,7 @@ class TableRead:
read_type: List[DataField],
include_row_kind: bool = False,
nested_name_paths: Optional[List[List[str]]] = None,
+ limit: Optional[int] = None,
):
from pypaimon.table.file_store_table import FileStoreTable
@@ -50,14 +51,24 @@ class TableRead:
self.read_type = read_type
self.include_row_kind = include_row_kind
self.nested_name_paths = nested_name_paths
+ self.limit = limit
def to_iterator(self, splits: List[Split]) -> Iterator:
+ limit = self.limit
+
def _record_generator():
+ count = 0
for split in splits:
+ if limit is not None and count >= limit:
+ return
reader = self._create_split_read(split).create_reader()
try:
for batch in iter(reader.read_batch, None):
- yield from iter(batch.next, None)
+ for row in iter(batch.next, None):
+ yield row
+ count += 1
+ if limit is not None and count >= limit:
+ return
finally:
reader.close()
@@ -113,21 +124,34 @@ class TableRead:
def _arrow_batch_generator(self, splits: List[Split], schema:
pyarrow.Schema) -> Iterator[pyarrow.RecordBatch]:
chunk_size = 65536
+ # ``remaining`` tracks how many rows we are still allowed to emit
+ # across all splits. ``None`` means unlimited.
+ remaining = self.limit
for split in splits:
+ if remaining is not None and remaining <= 0:
+ break
reader = self._create_split_read(split).create_reader()
try:
if isinstance(reader, RecordBatchReader):
- # Add row kind column if requested (default to +I for
RecordBatchReader)
- if self.include_row_kind:
- for batch in iter(reader.read_arrow_batch, None):
- yield self._add_row_kind_column_to_batch(batch,
"+I")
- else:
- yield from iter(reader.read_arrow_batch, None)
+ for batch in iter(reader.read_arrow_batch, None):
+ if remaining is not None and batch.num_rows >
remaining:
+ batch = batch.slice(0, remaining)
+ if self.include_row_kind:
+ batch = self._add_row_kind_column_to_batch(batch,
"+I")
+ yield batch
+ if remaining is not None:
+ remaining -= batch.num_rows
+ if remaining <= 0:
+ break
else:
row_tuple_chunk = []
row_kind_chunk = []
- for row_iterator in iter(reader.read_batch, None):
+ while True:
+ row_iterator = reader.read_batch()
+ if row_iterator is None:
+ break
+ stop = False
for row in iter(row_iterator.next, None):
if not isinstance(row, OffsetRow):
raise TypeError(f"Expected OffsetRow, but got
{type(row).__name__}")
@@ -135,6 +159,12 @@ class TableRead:
if self.include_row_kind:
row_kind_chunk.append(row.get_row_kind().to_string())
+ if remaining is not None:
+ remaining -= 1
+ if remaining <= 0:
+ stop = True
+ break
+
if len(row_tuple_chunk) >= chunk_size:
batch =
self._convert_rows_to_arrow_batch_with_row_kind(
row_tuple_chunk, row_kind_chunk, schema
@@ -142,6 +172,8 @@ class TableRead:
yield batch
row_tuple_chunk = []
row_kind_chunk = []
+ if stop:
+ break
if row_tuple_chunk:
batch =
self._convert_rows_to_arrow_batch_with_row_kind(
@@ -242,15 +274,22 @@ class TableRead:
splits=splits,
read_type=self.read_type,
predicate=self.predicate,
+ limit=self.limit,
)
)
- return ray.data.read_datasource(
+ ds = ray.data.read_datasource(
datasource,
ray_remote_args=ray_remote_args,
concurrency=concurrency,
override_num_blocks=override_num_blocks,
**read_args
)
+ # Each Ray worker applies the per-task limit independently, so N
+ # workers can collectively yield up to N * limit rows. Cap the
+ # final dataset to the user-visible limit on top.
+ if self.limit is not None:
+ ds = ds.limit(self.limit)
+ return ds
def to_torch(
self,
@@ -285,6 +324,7 @@ class TableRead:
split=split,
row_tracking_enabled=False,
outer_extract_name_paths=outer_extract_name_paths,
+ limit=self.limit,
)
elif self.table.options.data_evolution_enabled():
if self.nested_name_paths and any(
diff --git a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
index fcbcf558ae..a0a3ad37e8 100644
--- a/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
+++ b/paimon-python/pypaimon/tests/py36/rest_ao_read_write_test.py
@@ -613,11 +613,12 @@ class RESTAOReadWritePy36Test(RESTBaseTest):
table = self.rest_catalog.get_table('default.test_append_only_limit')
self._write_test_table(table)
+ # Row-level limit: the reader stops at exactly N rows (not "first
+ # split's full row count"). Scan still keeps the first split that
+ # covers the limit; the reader short-circuits inside it.
read_builder = table.new_read_builder().with_limit(1)
actual = self._read_test_table(read_builder)
- # only records from 1st commit (1st split) will be read
- # might be split of "dt=1" or split of "dt=2"
- self.assertEqual(actual.num_rows, 4)
+ self.assertEqual(actual.num_rows, 1)
def test_write_wrong_schema(self):
self.rest_catalog.create_table('default.test_wrong_schema',
diff --git a/paimon-python/pypaimon/tests/reader_append_only_test.py
b/paimon-python/pypaimon/tests/reader_append_only_test.py
index 8006d122d9..5cb8a2dfca 100644
--- a/paimon-python/pypaimon/tests/reader_append_only_test.py
+++ b/paimon-python/pypaimon/tests/reader_append_only_test.py
@@ -658,11 +658,12 @@ class AoReaderTest(unittest.TestCase):
table = self.catalog.get_table('default.test_append_only_limit')
self._write_test_table(table)
+ # Row-level limit: the reader stops at exactly N rows (not "first
+ # split's full row count"). Scan still keeps the first split that
+ # covers the limit; the reader short-circuits inside it.
read_builder = table.new_read_builder().with_limit(1)
actual = self._read_test_table(read_builder)
- # only records from 1st commit (1st split) will be read
- # might be split of "dt=1" or split of "dt=2"
- self.assertEqual(actual.num_rows, 4)
+ self.assertEqual(actual.num_rows, 1)
def test_incremental_timestamp(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
diff --git a/paimon-python/pypaimon/tests/reader_primary_key_test.py
b/paimon-python/pypaimon/tests/reader_primary_key_test.py
index adc8c9cf9b..89a9e8a6a1 100644
--- a/paimon-python/pypaimon/tests/reader_primary_key_test.py
+++ b/paimon-python/pypaimon/tests/reader_primary_key_test.py
@@ -350,7 +350,6 @@ class PkReaderTest(unittest.TestCase):
len(merge_splits), 0,
"Should have at least one merge split to test limit with merge
scenario")
- total_unique_rows = 125
for limit in [5, 10, 20, 50]:
read_builder = table.new_read_builder().with_limit(limit)
table_read = read_builder.new_read()
@@ -362,9 +361,9 @@ class PkReaderTest(unittest.TestCase):
result = table_read.to_arrow(splits)
row_count = result.num_rows if result is not None else 0
self.assertEqual(
- row_count, total_unique_rows,
- f"with_limit({limit}) should return all rows for PK table "
- f"(read-level limit not yet implemented)")
+ row_count, limit,
+ f"with_limit({limit}) on PK table must return exactly "
+ f"{limit} rows now that the read-level limit is wired")
def test_incremental_timestamp(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
diff --git a/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
b/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
index bd357f907a..f46ce5444b 100644
--- a/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_read_write_test.py
@@ -302,11 +302,12 @@ class RESTTableReadWriteTest(RESTBaseTest):
table = self.rest_catalog.get_table('default.test_append_only_limit')
self._write_test_table(table)
+ # Row-level limit: the reader stops at exactly N rows (not "first
+ # split's full row count"). Scan still keeps the first split that
+ # covers the limit; the reader short-circuits inside it.
read_builder = table.new_read_builder().with_limit(1)
actual = self._read_test_table(read_builder)
- # only records from 1st commit (1st split) will be read
- # might be split of "dt=1" or split of "dt=2"
- self.assertEqual(actual.num_rows, 4)
+ self.assertEqual(actual.num_rows, 1)
def test_pk_parquet_reader(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
diff --git a/paimon-python/pypaimon/tests/test_limit_pushdown.py
b/paimon-python/pypaimon/tests/test_limit_pushdown.py
new file mode 100644
index 0000000000..2e717c28c7
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_limit_pushdown.py
@@ -0,0 +1,212 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""End-to-end coverage for ``with_limit`` after row-level pushdown.
+
+Locks the contract: ``with_limit(N)`` returns at most ``N`` rows, and
+the reader actually stops at that boundary instead of reading every
+split / merge output to completion and trimming at the consumer.
+"""
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+
+from pypaimon import CatalogFactory, Schema
+
+
+class LimitPushdownTest(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.mkdtemp()
+ cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+ cls.catalog_options = {'warehouse': cls.warehouse}
+ cls.catalog = CatalogFactory.create(cls.catalog_options)
+ cls.catalog.create_database('default', False)
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+ @staticmethod
+ def _ao_schema() -> pa.Schema:
+ return pa.schema([
+ pa.field('id', pa.int64(), nullable=False),
+ ('val', pa.int64()),
+ pa.field('dt', pa.string(), nullable=False),
+ ])
+
+ @staticmethod
+ def _pk_schema() -> pa.Schema:
+ return pa.schema([
+ pa.field('id', pa.int64(), nullable=False),
+ ('val', pa.int64()),
+ ])
+
+ def _create_ao_table(self, name: str):
+ identifier = 'default.' + name
+ schema = Schema.from_pyarrow_schema(
+ self._ao_schema(),
+ partition_keys=['dt'],
+ options={'file.format': 'parquet'},
+ )
+ self.catalog.create_table(identifier, schema, False)
+ return self.catalog.get_table(identifier)
+
+ def _create_pk_table(self, name: str, *, num_buckets: int = 1):
+ identifier = 'default.' + name
+ schema = Schema.from_pyarrow_schema(
+ self._pk_schema(),
+ primary_keys=['id'],
+ options={'bucket': str(num_buckets), 'file.format': 'parquet'},
+ )
+ self.catalog.create_table(identifier, schema, False)
+ return self.catalog.get_table(identifier)
+
+ def _write_ao_partitions(self, table, partitions):
+ for dt, rows in partitions:
+ wb = table.new_batch_write_builder()
+ w = wb.new_write()
+ data = pa.Table.from_pylist(
+ [{'id': r, 'val': r * 10, 'dt': dt} for r in rows],
+ schema=self._ao_schema())
+ w.write_arrow(data)
+ wb.new_commit().commit(w.prepare_commit())
+ w.close()
+
+ def _write_pk_snapshots(self, table, snapshots):
+ for rows in snapshots:
+ wb = table.new_batch_write_builder()
+ w = wb.new_write()
+ data = pa.Table.from_pylist(
+ [{'id': i, 'val': v} for i, v in rows],
schema=self._pk_schema())
+ w.write_arrow(data)
+ wb.new_commit().commit(w.prepare_commit())
+ w.close()
+
+ # ---- append-only -----------------------------------------------------
+
+ def test_append_only_limit_stops_within_first_split(self):
+ """With limit=3 on a partitioned append-only table, the result is
+ exactly 3 rows — even though each partition split has 5 rows."""
+ table = self._create_ao_table('limit_ao_within_split')
+ self._write_ao_partitions(table, [
+ ('p1', list(range(5))), # 5 rows
+ ('p2', list(range(5, 10))), # 5 rows
+ ])
+ rb = table.new_read_builder().with_limit(3)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(result.num_rows, 3)
+
+ def test_append_only_limit_spans_multiple_splits(self):
+ """Limit larger than first split: read carries over to the next
+ split until the budget is met."""
+ table = self._create_ao_table('limit_ao_span_splits')
+ self._write_ao_partitions(table, [
+ ('p1', [1, 2]),
+ ('p2', [3, 4]),
+ ('p3', [5, 6]),
+ ])
+ rb = table.new_read_builder().with_limit(5)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(result.num_rows, 5)
+
+ def test_append_only_limit_zero_returns_empty(self):
+ table = self._create_ao_table('limit_ao_zero')
+ self._write_ao_partitions(table, [('p1', [1, 2, 3])])
+ rb = table.new_read_builder().with_limit(0)
+ splits = rb.new_scan().plan().splits()
+ result = rb.new_read().to_arrow(splits)
+ self.assertEqual(result.num_rows, 0)
+
+ def test_append_only_limit_larger_than_total(self):
+ """Limit greater than the total returns the total, not the limit."""
+ table = self._create_ao_table('limit_ao_oversize')
+ self._write_ao_partitions(table, [('p1', [1, 2, 3])])
+ rb = table.new_read_builder().with_limit(100)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(result.num_rows, 3)
+
+ # ---- PK merge-on-read ------------------------------------------------
+
+ def test_pk_merge_limit_stops_within_first_split(self):
+ """PK + multiple snapshots forces the merge-read path. The reader
+ must stop at limit rows instead of running every section to
+ completion and trimming at the consumer."""
+ table = self._create_pk_table('limit_pk_within_split')
+ # Two snapshots over the same key range → merge path; total
+ # post-merge unique rows = 20.
+ self._write_pk_snapshots(table, [
+ [(i, i) for i in range(20)],
+ [(i, i + 1000) for i in range(0, 20, 2)],
+ ])
+ for limit in (1, 5, 10, 19):
+ rb = table.new_read_builder().with_limit(limit)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(
+ result.num_rows, limit,
+ "with_limit(%d) must short-circuit at the row level" % limit)
+
+ def test_pk_merge_limit_equals_total(self):
+ """Limit equal to total post-merge row count: returns everything."""
+ table = self._create_pk_table('limit_pk_equals_total')
+ self._write_pk_snapshots(table, [
+ [(i, i) for i in range(10)],
+ [(i, i + 100) for i in range(5)],
+ ])
+ rb = table.new_read_builder().with_limit(10)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(result.num_rows, 10)
+
+ def test_pk_merge_limit_with_predicate(self):
+ """``with_limit`` plus ``with_filter``: the filter prunes first and
+ the limit caps what survives. ``val >= 1000`` matches the latest
+ write of the even ``id`` rows; limit then takes the prefix."""
+ table = self._create_pk_table('limit_pk_with_filter')
+ self._write_pk_snapshots(table, [
+ [(i, i) for i in range(20)],
+ [(i, i + 1000) for i in range(0, 20, 2)], # update evens
+ ])
+ rb = table.new_read_builder()
+ pred = rb.new_predicate_builder().greater_or_equal('val', 1000)
+ rb = rb.with_filter(pred).with_limit(3)
+ result = rb.new_read().to_arrow(rb.new_scan().plan().splits())
+ self.assertEqual(result.num_rows, 3)
+ for v in result.column('val').to_pylist():
+ self.assertGreaterEqual(v, 1000)
+
+ # ---- to_iterator path ------------------------------------------------
+
+ def test_to_iterator_limit_short_circuits(self):
+ table = self._create_ao_table('limit_iter')
+ self._write_ao_partitions(table, [
+ ('p1', list(range(50))),
+ ('p2', list(range(50, 100))),
+ ])
+ rb = table.new_read_builder().with_limit(7)
+ it = rb.new_read().to_iterator(rb.new_scan().plan().splits())
+ rows = list(it)
+ self.assertEqual(len(rows), 7)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/paimon-python/pypaimon/tests/test_limited_record_reader.py
b/paimon-python/pypaimon/tests/test_limited_record_reader.py
new file mode 100644
index 0000000000..edbc1b75f4
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_limited_record_reader.py
@@ -0,0 +1,141 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import unittest
+from typing import List, Optional
+
+from pypaimon.read.reader.iface.record_iterator import RecordIterator
+from pypaimon.read.reader.iface.record_reader import RecordReader
+from pypaimon.read.reader.limited_record_reader import LimitedRecordReader
+
+
+class _ListIterator(RecordIterator):
+ def __init__(self, items: List):
+ self._items = items
+ self._idx = 0
+
+ def next(self):
+ if self._idx >= len(self._items):
+ return None
+ v = self._items[self._idx]
+ self._idx += 1
+ return v
+
+
+class _StaticReader(RecordReader):
+ """Hands back batches one at a time, tracks close calls and the
+ number of times ``read_batch`` was invoked (so tests can prove the
+ limiter actually short-circuited instead of draining the inner)."""
+
+ def __init__(self, batches: List[List]):
+ self._batches = batches
+ self._idx = 0
+ self.closed = False
+ self.read_batch_calls = 0
+
+ def read_batch(self) -> Optional[RecordIterator]:
+ self.read_batch_calls += 1
+ if self._idx >= len(self._batches):
+ return None
+ batch = self._batches[self._idx]
+ self._idx += 1
+ return _ListIterator(batch)
+
+ def close(self):
+ self.closed = True
+
+
+def _drain(reader: RecordReader) -> List:
+ out = []
+ while True:
+ batch = reader.read_batch()
+ if batch is None:
+ break
+ while True:
+ v = batch.next()
+ if v is None:
+ break
+ out.append(v)
+ return out
+
+
+class LimitedRecordReaderTest(unittest.TestCase):
+
+ def test_limit_within_first_batch(self):
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2, 3, 4, 5]]), limit=3)
+ self.assertEqual(_drain(reader), [1, 2, 3])
+
+ def test_limit_spans_multiple_batches(self):
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2], [3, 4], [5, 6]]), limit=5)
+ self.assertEqual(_drain(reader), [1, 2, 3, 4, 5])
+
+ def test_limit_larger_than_total_returns_everything(self):
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2, 3]]), limit=999)
+ self.assertEqual(_drain(reader), [1, 2, 3])
+
+ def test_limit_zero_returns_nothing(self):
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2, 3]]), limit=0)
+ self.assertEqual(_drain(reader), [])
+ # read_batch should short-circuit immediately rather than peek.
+ self.assertIsNone(reader.read_batch())
+
+ def test_negative_limit_rejected(self):
+ with self.assertRaises(ValueError):
+ LimitedRecordReader(_StaticReader([]), limit=-1)
+
+ def test_close_propagates(self):
+ inner = _StaticReader([[1, 2]])
+ reader = LimitedRecordReader(inner, limit=10)
+ reader.close()
+ self.assertTrue(inner.closed)
+
+ def test_iterator_stops_mid_batch(self):
+ # Limit cuts halfway through a batch; the next() call past the limit
+ # must return None even though the inner batch still has items.
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2, 3, 4, 5]]), limit=2)
+ batch = reader.read_batch()
+ self.assertEqual(batch.next(), 1)
+ self.assertEqual(batch.next(), 2)
+ self.assertIsNone(batch.next())
+ # Subsequent read_batch is also None.
+ self.assertIsNone(reader.read_batch())
+
+ def test_count_visible_for_observability(self):
+ reader = LimitedRecordReader(
+ _StaticReader([[1, 2, 3, 4]]), limit=10)
+ _drain(reader)
+ self.assertEqual(reader.count, 4)
+
+ def test_does_not_drain_inner_when_limit_met_within_first_batch(self):
+ """Direct proof of the short-circuit: once the limiter has handed
+ out ``limit`` rows the next ``read_batch`` short-circuits at the
+ entry guard and never pulls a second batch from the inner."""
+ inner = _StaticReader([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]])
+ reader = LimitedRecordReader(inner, limit=3)
+ self.assertEqual(_drain(reader), [1, 2, 3])
+ # Only the first batch was fetched; the second is never asked for.
+ self.assertEqual(inner.read_batch_calls, 1)
+
+
+if __name__ == '__main__':
+ unittest.main()