This is an automated email from the ASF dual-hosted git repository.

sungwy pushed a commit to branch pyiceberg-0.7.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git

commit f73da80caf6409fc6442d1c530078cabbec58f03
Author: Kevin Liu <[email protected]>
AuthorDate: Mon Aug 12 14:48:21 2024 -0700

    [bug] fix reading with `to_arrow_batch_reader` and `limit` (#1042)
    
    * fix project_batches with limit
    
    * add test
    
    * lint + readability
---
 pyiceberg/io/pyarrow.py         |  8 +++++--
 tests/integration/test_reads.py | 48 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 52188459..b8471ee5 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1409,6 +1409,9 @@ def project_batches(
     total_row_count = 0
 
     for task in tasks:
+        # stop early if limit is satisfied
+        if limit is not None and total_row_count >= limit:
+            break
         batches = _task_to_record_batches(
             fs,
             task,
@@ -1421,9 +1424,10 @@ def project_batches(
         )
         for batch in batches:
             if limit is not None:
-                if total_row_count + len(batch) >= limit:
-                    yield batch.slice(0, limit - total_row_count)
+                if total_row_count >= limit:
                     break
+                elif total_row_count + len(batch) >= limit:
+                    batch = batch.slice(0, limit - total_row_count)
             yield batch
             total_row_count += len(batch)
 
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 82712675..cee8839d 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -240,6 +240,54 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
     assert len(full_result) == 10
 
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
+    table_name = "default.test_pyarrow_limit_with_multiple_files"
+    try:
+        catalog.drop_table(table_name)
+    except NoSuchTableError:
+        pass
+    reference_table = catalog.load_table("default.test_limit")
+    data = reference_table.scan().to_arrow()
+    table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
+
+    n_files = 2
+    for _ in range(n_files):
+        table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == n_files
+
+    # test with multiple files
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+    assert len(full_result) == 10 * n_files
+
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10 * n_files
+
 
 @pytest.mark.integration
 @pytest.mark.filterwarnings("ignore")

Reply via email to