This is an automated email from the ASF dual-hosted git repository.

sungwy pushed a commit to branch pyiceberg-0.7.x
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
commit e00a55c80412affd92f7eb548238ce1d3ad822a7
Author: Sung Yun <[email protected]>
AuthorDate: Fri Aug 9 04:32:17 2024 -0400

    Handle Empty `RecordBatch` within `_task_to_record_batches` (#1026)
---
 pyiceberg/io/pyarrow.py           |  7 ++--
 tests/integration/test_deletes.py | 67 +++++++++++++++++++++++++++++++++++++++
 tests/integration/test_reads.py   | 28 ++++++++++++++++
 3 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 33561da5..6c5db515 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1205,9 +1205,11 @@ def _task_to_record_batches(
             columns=[col.name for col in file_project_schema.columns],
         )
 
-        current_index = 0
+        next_index = 0
         batches = fragment_scanner.to_batches()
         for batch in batches:
+            next_index = next_index + len(batch)
+            current_index = next_index - len(batch)
             if positional_deletes:
                 # Create the mask of indices that we're interested in
                 indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
@@ -1219,9 +1221,10 @@
                     # https://github.com/apache/arrow/issues/39220
                     arrow_table = pa.Table.from_batches([batch])
                     arrow_table = arrow_table.filter(pyarrow_filter)
+                    if len(arrow_table) == 0:
+                        continue
                     batch = arrow_table.to_batches()[0]
             yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True)
-            current_index += len(batch)
 
 
 def _task_to_table(
diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py
index 4bddf09b..be02696a 100644
--- a/tests/integration/test_deletes.py
+++ b/tests/integration/test_deletes.py
@@ -222,6 +222,73 @@ def test_delete_partitioned_table_positional_deletes(spark: SparkSession, sessio
     assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]}
 
 
[email protected]
+def test_delete_partitioned_table_positional_deletes_empty_batch(spark: SparkSession, session_catalog: RestCatalog) -> None:
+    identifier = "default.test_delete_partitioned_table_positional_deletes_empty_batch"
+
+    run_spark_commands(
+        spark,
+        [
+            f"DROP TABLE IF EXISTS {identifier}",
+            f"""
+            CREATE TABLE {identifier} (
+                number_partitioned int,
+                number int
+            )
+            USING iceberg
+            PARTITIONED BY (number_partitioned)
+            TBLPROPERTIES(
+                'format-version' = 2,
+                'write.delete.mode'='merge-on-read',
+                'write.update.mode'='merge-on-read',
+                'write.merge.mode'='merge-on-read',
+                'write.parquet.row-group-limit'=1
+            )
+            """,
+        ],
+    )
+
+    tbl = session_catalog.load_table(identifier)
+
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([10, 10, 10]),
+            pa.array([1, 2, 3]),
+        ],
+        schema=pa.schema([pa.field("number_partitioned", pa.int32()), pa.field("number", pa.int32())]),
+    )
+
+    tbl.append(arrow_table)
+
+    assert len(tbl.scan().to_arrow()) == 3
+
+    run_spark_commands(
+        spark,
+        [
+            # Generate a positional delete
+            f"""
+            DELETE FROM {identifier} WHERE number = 1
+            """,
+        ],
+    )
+    # Assert that there is just a single Parquet file, that has one merge on read file
+    tbl = tbl.refresh()
+
+    files = list(tbl.scan().plan_files())
+    assert len(files) == 1
+    assert len(files[0].delete_files) == 1
+
+    assert len(tbl.scan().to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 10").to_arrow()) == 2
+
+    assert len(tbl.scan(row_filter="number_partitioned == 1").to_arrow()) == 0
+
+    reader = tbl.scan(row_filter="number_partitioned == 1").to_arrow_batch_reader()
+    assert isinstance(reader, pa.RecordBatchReader)
+    assert len(reader.read_all()) == 0
+
+
 @pytest.mark.integration
 def test_overwrite_partitioned_table(spark: SparkSession, session_catalog: RestCatalog) -> None:
     identifier = "default.table_partitioned_delete"
diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 699cdb17..82712675 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -707,3 +707,31 @@ def test_empty_scan_ordered_str(catalog: Catalog) -> None:
     table_empty_scan_ordered_str = catalog.load_table("default.test_empty_scan_ordered_str")
     arrow_table = table_empty_scan_ordered_str.scan(EqualTo("id", "b")).to_arrow()
     assert len(arrow_table) == 0
+
+
[email protected]
[email protected]("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_table_scan_empty_table(catalog: Catalog) -> None:
+    identifier = "default.test_table_scan_empty_table"
+    arrow_table = pa.Table.from_arrays(
+        [
+            pa.array([]),
+        ],
+        schema=pa.schema([pa.field("colA", pa.string())]),
+    )
+
+    try:
+        catalog.drop_table(identifier)
+    except NoSuchTableError:
+        pass
+
+    tbl = catalog.create_table(
+        identifier,
+        schema=arrow_table.schema,
+    )
+
+    tbl.append(arrow_table)
+
+    result_table = tbl.scan().to_arrow()
+
+    assert len(result_table) == 0
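Editor's note (not part of the commit): when the `pyarrow_filter` removes every row of a batch, the round trip through `pa.Table` can leave no record batches to hand back, so `to_batches()[0]` has nothing to index; the added `if len(arrow_table) == 0: continue` skips such batches, and the `next_index`/`current_index` change derives positional-delete offsets from the unfiltered batch length at the top of the loop instead of incrementing after a possibly filtered (or skipped) batch. Below is a minimal standalone sketch of that guard, assuming a recent pyarrow where `Table.filter` accepts a compute expression; `filter_batches` and the sample data are illustrative only and are not pyiceberg APIs.

```python
import pyarrow as pa
import pyarrow.compute as pc


def filter_batches(batches, pyarrow_filter):
    """Illustrative only: mimic the Table round trip used in _task_to_record_batches."""
    for batch in batches:
        # Expression filters aren't supported on RecordBatch directly,
        # so go through a single-batch Table, as the patched code does.
        table = pa.Table.from_batches([batch]).filter(pyarrow_filter)
        if len(table) == 0:
            # Without this guard, an all-filtered batch may leave no batches
            # to index, which is the empty-RecordBatch case the commit handles.
            continue
        yield table.to_batches()[0]


batch = pa.RecordBatch.from_pydict({"number_partitioned": [10, 10, 10], "number": [1, 2, 3]})

# Mirrors the test's row_filter="number_partitioned == 1", which matches no rows.
assert list(filter_batches([batch], pc.field("number_partitioned") == 1)) == []

# A predicate that does match still yields the surviving rows.
assert sum(len(b) for b in filter_batches([batch], pc.field("number") >= 2)) == 2
```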
