This is an automated email from the ASF dual-hosted git repository.
kevinjqliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 5773b7f1 perf: iterate over generators when writing datafiles to
reduce memory pressure (#2671)
5773b7f1 is described below
commit 5773b7f1bf2081a90a490f9d670eef804eb88ab4
Author: Alex <[email protected]>
AuthorDate: Mon Nov 3 11:46:49 2025 -0700
perf: iterate over generators when writing datafiles to reduce memory
pressure (#2671)
# Rationale for this change
When writing to partitioned tables, there is a large memory spike when
the partitions are computed, because we call `.combine_chunks()` on the new
partitioned arrow tables and we materialize the entire list of
partitions before writing data files.
This PR switches the partition computation to a generator to avoid
materializing all the partitions in memory at once, reducing the memory
overhead of writing to partitioned tables.
## Are these changes tested?
No new tests. The tests using this method were updated to consume the
generator as a list.
However, in my personal use case, I am using
`pa.total_allocated_bytes()` to determine memory allocation before and
after the write and see the following across 5 writes of ~128 MB:
| Run | Original Impl (Before Write) | Original Impl (After Write) | Iters (Before Write) | Iters (After Write) |
|---|---|---|---|---|
| 1 | 29.31 MB | 151.62 MB | 28.38 MB | 30.40 MB |
| 2 | 27.74 MB | 151.62 MB | 28.85 MB | 30.36 MB |
| 3 | 28.81 MB | 151.62 MB | 28.52 MB | 31.29 MB |
| 4 | 28.71 MB | 151.62 MB | 29.27 MB | 30.64 MB |
| 5 | 28.60 MB | 151.61 MB | 28.29 MB | 31.11 MB |
This scales with the size of the write: to write a 3 GB arrow
table to a partitioned table, I need at least 6 GB of RAM.
## Are there any user-facing changes?
No.
---
pyiceberg/io/pyarrow.py | 41 +++++++++++++++++------------------------
tests/io/test_pyarrow.py | 8 ++++----
2 files changed, 21 insertions(+), 28 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index e42c1307..7710df76 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2790,11 +2790,9 @@ def _dataframe_to_data_files(
yield from write_file(
io=io,
table_metadata=table_metadata,
- tasks=iter(
- [
- WriteTask(write_uuid=write_uuid, task_id=next(counter),
record_batches=batches, schema=task_schema)
- for batches in bin_pack_arrow_table(df, target_file_size)
- ]
+ tasks=(
+ WriteTask(write_uuid=write_uuid, task_id=next(counter),
record_batches=batches, schema=task_schema)
+ for batches in bin_pack_arrow_table(df, target_file_size)
),
)
else:
@@ -2802,18 +2800,16 @@ def _dataframe_to_data_files(
yield from write_file(
io=io,
table_metadata=table_metadata,
- tasks=iter(
- [
- WriteTask(
- write_uuid=write_uuid,
- task_id=next(counter),
- record_batches=batches,
- partition_key=partition.partition_key,
- schema=task_schema,
- )
- for partition in partitions
- for batches in
bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
- ]
+ tasks=(
+ WriteTask(
+ write_uuid=write_uuid,
+ task_id=next(counter),
+ record_batches=batches,
+ partition_key=partition.partition_key,
+ schema=task_schema,
+ )
+ for partition in partitions
+ for batches in
bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
),
)
@@ -2824,7 +2820,7 @@ class _TablePartition:
arrow_table_partition: pa.Table
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table:
pa.Table) -> List[_TablePartition]:
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table:
pa.Table) -> Iterable[_TablePartition]:
"""Based on the iceberg table partition spec, filter the arrow table into
partitions with their keys.
Example:
@@ -2852,8 +2848,6 @@ def _determine_partitions(spec: PartitionSpec, schema:
Schema, arrow_table: pa.T
unique_partition_fields =
arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
- table_partitions = []
- # TODO: As a next step, we could also play around with yielding instead of
materializing the full list
for unique_partition in unique_partition_fields.to_pylist():
partition_key = PartitionKey(
field_values=[
@@ -2880,12 +2874,11 @@ def _determine_partitions(spec: PartitionSpec, schema:
Schema, arrow_table: pa.T
# The combine_chunks seems to be counter-intuitive to do, but it
actually returns
# fresh buffers that don't interfere with each other when it is
written out to file
- table_partitions.append(
- _TablePartition(partition_key=partition_key,
arrow_table_partition=filtered_table.combine_chunks())
+ yield _TablePartition(
+ partition_key=partition_key,
+ arrow_table_partition=filtered_table.combine_chunks(),
)
- return table_partitions
-
def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) ->
pa.Array:
"""Get a field from an Arrow table, supporting both literal field names
and nested field paths.
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index a19ddd60..45b9d9c9 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -2479,7 +2479,7 @@ def test_partition_for_demo() -> None:
PartitionField(source_id=2, field_id=1002,
transform=IdentityTransform(), name="n_legs_identity"),
PartitionField(source_id=1, field_id=1001,
transform=IdentityTransform(), name="year_identity"),
)
- result = _determine_partitions(partition_spec, test_schema, arrow_table)
+ result = list(_determine_partitions(partition_spec, test_schema,
arrow_table))
assert {table_partition.partition_key.partition for table_partition in
result} == {
Record(2, 2020),
Record(100, 2021),
@@ -2518,7 +2518,7 @@ def test_partition_for_nested_field() -> None:
]
arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
- partitions = _determine_partitions(spec, schema, arrow_table)
+ partitions = list(_determine_partitions(spec, schema, arrow_table))
partition_values = {p.partition_key.partition[0] for p in partitions}
assert partition_values == {486729, 486730}
@@ -2550,7 +2550,7 @@ def test_partition_for_deep_nested_field() -> None:
]
arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
- partitions = _determine_partitions(spec, schema, arrow_table)
+ partitions = list(_determine_partitions(spec, schema, arrow_table))
assert len(partitions) == 2 # 2 unique partitions
partition_values = {p.partition_key.partition[0] for p in partitions}
@@ -2621,7 +2621,7 @@ def test_identity_partition_on_multi_columns() -> None:
}
arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
- result = _determine_partitions(partition_spec, test_schema,
arrow_table)
+ result = list(_determine_partitions(partition_spec, test_schema,
arrow_table))
assert {table_partition.partition_key.partition for table_partition in
result} == expected
concatenated_arrow_table =
pa.concat_tables([table_partition.arrow_table_partition for table_partition in
result])