This is an automated email from the ASF dual-hosted git repository.

sungwy pushed a commit to branch pyiceberg-0.7.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git

commit e00a55c80412affd92f7eb548238ce1d3ad822a7
Author: Sung Yun <[email protected]>
AuthorDate: Fri Aug 9 04:32:17 2024 -0400

    Handle Empty `RecordBatch` within `_task_to_record_batches` (#1026)
---
 pyiceberg/io/pyarrow.py           |  7 ++--
 tests/integration/test_deletes.py | 67 +++++++++++++++++++++++++++++++++++++++
 tests/integration/test_reads.py   | 28 ++++++++++++++++
 3 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 33561da5..6c5db515 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1205,9 +1205,11 @@ def _task_to_record_batches(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        current_index = 0
+        next_index = 0
         batches = fragment_scanner.to_batches()
         for batch in batches:
+            next_index = next_index + len(batch)
+            current_index = next_index - len(batch)
             if positional_deletes:
                 # Create the mask of indices that we're interested in
                 indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
@@ -1219,9 +1221,10 @@ def _task_to_record_batches(
                     # https://github.com/apache/arrow/issues/39220
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
+                    if len(arrow_table) == 0:
+                        continue
                     batch = arrow_table.to_batches()[0]
             yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
-            current_index += len(batch)
 
 
 def _task_to_table(
diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py
index 4bddf09b..be02696a 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -222,6 +222,73 @@ def test_delete_partitioned_table_positional_deletes(spark: SparkSession, sessio
     assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]}
 
 
+@pytest.mark.integration
+def test_delete_partitioned_table_positional_deletes_empty_batch(spark: SparkSession, session_catalog: RestCatalog) -> None:
+    identifier = "default.test_delete_partitioned_table_positional_deletes_empty_batch"
+
+    run_spark_commands(
+        spark,
+        [
+            f"DROP TABLE IF EXISTS {identifier}",
+            f"""
+            CREATE TABLE {identifier} (
+                number_partitioned  int,
+                number              int
+            )
+            USING iceberg
+            PARTITIONED BY (number_partitioned)
+            TBLPROPERTIES(
+                'format-version' = 2,
+                'write.delete.mode'='merge-on-read',
+                'write.update.mode'='merge-on-read',
+                'write.merge.mode'='merge-on-read',
+                'write.parquet.row-group-limit'=1
+            )
+        """,
+        ],
+    )
+
+    tbl = session_catalog.load_table(identifier)
+
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([10, 10, 10]),
+            pa.array([1, 2, 3]),
+        ],
+        schema=pa.schema([pa.field("number_partitioned", pa.int32()), 
pa.field("number", pa.int32())]),
+    )
+
+    tbl.append(arrow_table)
+
+    assert len(tbl.scan().to_arrow()) == 3
+
+    run_spark_commands(
+        spark,
+        [
+            # Generate a positional delete
+            f"""
+            DELETE FROM {identifier} WHERE number = 1
+        """,
+        ],
+    )
+    # Assert that there is just a single Parquet file, which has one merge-on-read delete file
+    tbl = tbl.refresh()
+
+    files = list(tbl.scan().plan_files())
+    assert len(files) == 1
+    assert len(files[0].delete_files) == 1
+
+    assert len(tbl.scan().to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 10").to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 1").to_arrow()) == 0
+
+    reader = tbl.scan(row_filter="number_partitioned == 
1").to_arrow_batch_reader()
+    assert isinstance(reader, pa.RecordBatchReader)
+    assert len(reader.read_all()) == 0
+
+
 @pytest.mark.integration
 def test_overwrite_partitioned_table(spark: SparkSession, session_catalog: RestCatalog) -> None:
     identifier = "default.table_partitioned_delete"
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 699cdb17..82712675 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -707,3 +707,31 @@ def test_empty_scan_ordered_str(catalog: Catalog) -> None:
     table_empty_scan_ordered_str = catalog.load_table("default.test_empty_scan_ordered_str")
     arrow_table = table_empty_scan_ordered_str.scan(EqualTo("id", "b")).to_arrow()
     assert len(arrow_table) == 0
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_table_scan_empty_table(catalog: Catalog) -> None:
+    identifier = "default.test_table_scan_empty_table"
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([]),
+        ],
+        schema=pa.schema([pa.field("colA", pa.string())]),
+    )
+
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = catalog.create_table(
+        identifier,
+        schema=arrow_table.schema,
+    )
+
+    tbl.append(arrow_table)
+
+    result_table = tbl.scan().to_arrow()
+
+    assert len(result_table) == 0
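
The change above boils down to two things: each batch's absolute row window is now derived from a running next_index before any rows are removed, and a batch that ends up empty after the row filter is skipped instead of being rebuilt via to_batches()[0], which has nothing to index when the filtered table has zero rows. The following is a minimal, self-contained sketch of that pattern; the function name scan_batches and its parameters (deleted_positions, predicate_column, predicate_value) are invented for illustration and are not part of the PyIceberg API.

import pyarrow as pa
import pyarrow.compute as pc


def scan_batches(batches, deleted_positions, predicate_column, predicate_value):
    """Illustrative only: windowed positional deletes plus an empty-batch skip."""
    next_index = 0
    for batch in batches:
        # Advance the absolute [current_index, next_index) window first, so
        # positional deletes keep referring to original file row positions
        # even when earlier batches were filtered down.
        next_index += len(batch)
        current_index = next_index - len(batch)

        local_deletes = {p - current_index for p in deleted_positions if current_index <= p < next_index}
        if local_deletes:
            keep = [i for i in range(len(batch)) if i not in local_deletes]
            batch = batch.take(pa.array(keep, type=pa.int64()))

        table = pa.Table.from_batches([batch])
        table = table.filter(pc.equal(table[predicate_column], predicate_value))
        if len(table) == 0:
            # A fully filtered-out batch may yield no record batches at all,
            # so indexing to_batches()[0] would fail; skip it instead.
            continue
        yield table.to_batches()[0]


# Tiny usage example: the second source batch is filtered away entirely and skipped.
source = pa.table({"number": [1, 2, 3, 4]}).to_batches(max_chunksize=2)
result = list(scan_batches(source, deleted_positions={0}, predicate_column="number", predicate_value=2))
assert sum(len(b) for b in result) == 1  # only the row with number == 2 survives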
