This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f47dfd2 Move determine_partitions and helper methods to io.pyarrow (#906)
8f47dfd2 is described below

commit 8f47dfd2a0f586d58aa29e165540706066ea5282
Author: Soumya Ghosh <[email protected]>
AuthorDate: Thu Jul 11 11:52:55 2024 +0530

    Move determine_partitions and helper methods to io.pyarrow (#906)
---
 pyiceberg/io/pyarrow.py     | 101 ++++++++++++++++++++++++++++++++++++++++++--
 pyiceberg/table/__init__.py | 100 -------------------------------------------
 tests/io/test_pyarrow.py    |  84 +++++++++++++++++++++++++++++++++++-
 tests/table/test_init.py    |  84 ------------------------------------
 4 files changed, 180 insertions(+), 189 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index ae7799cf..f28fe76b 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -113,7 +113,7 @@ from pyiceberg.manifest import (
     DataFileContent,
     FileFormat,
 )
-from pyiceberg.partitioning import PartitionField, PartitionSpec, 
partition_record_value
+from pyiceberg.partitioning import PartitionField, PartitionFieldValue, 
PartitionKey, PartitionSpec, partition_record_value
 from pyiceberg.schema import (
     PartnerAccessor,
     PreOrderSchemaVisitor,
@@ -2125,8 +2125,6 @@ def _dataframe_to_data_files(
             ]),
         )
     else:
-        from pyiceberg.table import _determine_partitions
-
         partitions = _determine_partitions(spec=table_metadata.spec(), 
schema=table_metadata.schema(), arrow_table=df)
         yield from write_file(
             io=io,
@@ -2143,3 +2141,100 @@ def _dataframe_to_data_files(
                 for batches in 
bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ]),
         )
+
+
+@dataclass(frozen=True)
+class _TablePartition:
+    partition_key: PartitionKey
+    arrow_table_partition: pa.Table
+
+
+def _get_table_partitions(
+    arrow_table: pa.Table,
+    partition_spec: PartitionSpec,
+    schema: Schema,
+    slice_instructions: list[dict[str, Any]],
+) -> list[_TablePartition]:
+    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: 
x["offset"])
+
+    partition_fields = partition_spec.fields
+
+    offsets = [inst["offset"] for inst in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: 
arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, 
projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, 
partition_spec=partition_spec, schema=schema)
+        table_partitions.append(_TablePartition(partition_key=partition_key, 
arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: 
pa.Table) -> List[_TablePartition]:
+    """Based on the iceberg table partition spec, slice the arrow table into 
partitions with their keys.
+
+    Example:
+    Input:
+    An arrow table with partition key of ['n_legs', 'year'] and with data of
+    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", 
"Horse","Brittle stars", "Centipede"]}.
+    The algorithm:
+    Firstly we group the rows into partitions by sorting with sort order 
[('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement of "at_end".
+    This gives the same table as raw input.
+    Then we sort_indices using reverse order of [('n_legs', 'descending'), 
('year', 'descending')]
+    and null_placement : "at_start".
+    This gives:
+    [8, 7, 4, 5, 6, 3, 1, 2, 0]
+    Based on this we get partition groups of indices:
+    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 
'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, 
{'offset': 0, 'length': 1}]
+    We then retrieve the partition keys by offsets.
+    And slice the arrow table by offsets and lengths of each partition.
+    """
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for 
partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): 
partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in 
partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in 
partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
+    last = len(reversed_indices)
+    reversed_indices_size = len(reversed_indices)
+    ptr = 0
+    while ptr < reversed_indices_size:
+        group_size = last - reversed_indices[ptr]
+        offset = reversed_indices[ptr]
+        slice_instructions.append({"offset": offset, "length": group_size})
+        last = reversed_indices[ptr]
+        ptr = ptr + group_size
+
+    table_partitions: List[_TablePartition] = 
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+    return table_partitions
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 4080f3a0..76382008 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -92,7 +92,6 @@ from pyiceberg.partitioning import (
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
-    PartitionFieldValue,
     PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
@@ -4412,105 +4411,6 @@ class InspectTable:
         )
 
 
-@dataclass(frozen=True)
-class TablePartition:
-    partition_key: PartitionKey
-    arrow_table_partition: pa.Table
-
-
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: 
x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: 
arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, 
projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, 
partition_spec=partition_spec, schema=schema)
-        table_partitions.append(TablePartition(partition_key=partition_key, 
arrow_table_partition=partition_slice))
-    return table_partitions
-
-
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: 
pa.Table) -> List[TablePartition]:
-    """Based on the iceberg table partition spec, slice the arrow table into 
partitions with their keys.
-
-    Example:
-    Input:
-    An arrow table with partition key of ['n_legs', 'year'] and with data of
-    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
-     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
-     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", 
"Horse","Brittle stars", "Centipede"]}.
-    The algorithm:
-    Firstly we group the rows into partitions by sorting with sort order 
[('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement of "at_end".
-    This gives the same table as raw input.
-    Then we sort_indices using reverse order of [('n_legs', 'descending'), 
('year', 'descending')]
-    and null_placement : "at_start".
-    This gives:
-    [8, 7, 4, 5, 6, 3, 1, 2, 0]
-    Based on this we get partition groups of indices:
-    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 
'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, 
{'offset': 0, 'length': 1}]
-    We then retrieve the partition keys by offsets.
-    And slice the arrow table by offsets and lengths of each partition.
-    """
-    import pyarrow as pa
-
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for 
partition_field in spec.fields
-    ]
-    partition_values_table = pa.table({
-        str(partition.field_id): 
partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-        for partition, field in partition_columns
-    })
-
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in 
partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in 
partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[TablePartition] = 
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
-
-    return table_partitions
-
-
 class _ManifestMergeManager(Generic[U]):
     _target_size_bytes: int
     _min_count_to_merge: int
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index ecb946a9..1b946899 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -61,6 +61,7 @@ from pyiceberg.io.pyarrow import (
     PyArrowFileIO,
     StatsAggregator,
     _ConvertToArrowSchema,
+    _determine_partitions,
     _primitive_to_physical,
     _read_deletes,
     bin_pack_arrow_table,
@@ -69,11 +70,12 @@ from pyiceberg.io.pyarrow import (
     schema_to_pyarrow,
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
-from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema, make_compatible_name, visit
 from pyiceberg.table import FileScanTask, TableProperties
 from pyiceberg.table.metadata import TableMetadataV2
-from pyiceberg.typedef import UTF8
+from pyiceberg.transforms import IdentityTransform
+from pyiceberg.typedef import UTF8, Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: 
pa.Table) -> None:
     # and will produce half the number of files if we double the target size
     bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, 
target_file_size=arrow_table_with_null.nbytes * 2)
     assert len(list(bin_packed)) == 5
+
+
+def test_partition_for_demo() -> None:
+    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), 
("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="year", field_type=StringType(), 
required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), 
required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), 
required=False),
+        schema_id=1,
+    )
+    test_data = {
+        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
+        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
+        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", 
"Horse", "Brittle stars", "Centipede"],
+    }
+    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
+    )
+    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    assert {table_partition.partition_key.partition for table_partition in 
result} == {
+        Record(n_legs_identity=2, year_identity=2020),
+        Record(n_legs_identity=100, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2022),
+        Record(n_legs_identity=2, year_identity=2022),
+        Record(n_legs_identity=5, year_identity=2019),
+    }
+    assert (
+        pa.concat_tables([table_partition.arrow_table_partition for 
table_partition in result]).num_rows == arrow_table.num_rows
+    )
+
+
+def test_identity_partition_on_multi_columns() -> None:
+    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", 
pa.int64()), ("animal", pa.string())])
+    test_schema = Schema(
+        NestedField(field_id=1, name="born_year", field_type=StringType(), 
required=False),
+        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), 
required=True),
+        NestedField(field_id=3, name="animal", field_type=StringType(), 
required=False),
+        schema_id=1,
+    )
+    # 5 partitions, 6 unique row values, 12 rows
+    test_rows = [
+        (2021, 4, "Dog"),
+        (2022, 4, "Horse"),
+        (2022, 4, "Another Horse"),
+        (2021, 100, "Centipede"),
+        (None, 4, "Kirin"),
+        (2021, None, "Fish"),
+    ] * 2
+    expected = {Record(n_legs_identity=test_rows[i][1], 
year_identity=test_rows[i][0]) for i in range(len(test_rows))}
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
+    )
+    import random
+
+    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
+    for _ in range(1000):
+        random.shuffle(test_rows)
+        test_data = {
+            "born_year": [row[0] for row in test_rows],
+            "n_legs": [row[1] for row in test_rows],
+            "animal": [row[2] for row in test_rows],
+        }
+        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+
+        result = _determine_partitions(partition_spec, test_schema, 
arrow_table)
+
+        assert {table_partition.partition_key.partition for table_partition in 
result} == expected
+        concatenated_arrow_table = 
pa.concat_tables([table_partition.arrow_table_partition for table_partition in 
result])
+        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
+        assert concatenated_arrow_table.sort_by([
+            ("born_year", "ascending"),
+            ("n_legs", "ascending"),
+            ("animal", "ascending"),
+        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", 
"ascending"), ("animal", "ascending")])
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index d7c4ffee..31a8bbf4 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -64,7 +64,6 @@ from pyiceberg.table import (
     UpdateSchema,
     _apply_table_update,
     _check_schema_compatible,
-    _determine_partitions,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
     update_table_metadata,
@@ -88,7 +87,6 @@ from pyiceberg.transforms import (
     BucketTransform,
     IdentityTransform,
 )
-from pyiceberg.typedef import Record
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -1248,85 +1246,3 @@ def test_serialize_commit_table_request() -> None:
 
     deserialized_request = 
CommitTableRequest.model_validate_json(request.model_dump_json())
     assert request == deserialized_request
-
-
-def test_partition_for_demo() -> None:
-    import pyarrow as pa
-
-    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), 
("animal", pa.string())])
-    test_schema = Schema(
-        NestedField(field_id=1, name="year", field_type=StringType(), 
required=False),
-        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), 
required=True),
-        NestedField(field_id=3, name="animal", field_type=StringType(), 
required=False),
-        schema_id=1,
-    )
-    test_data = {
-        "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
-        "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
-        "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", 
"Horse", "Brittle stars", "Centipede"],
-    }
-    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
-    partition_spec = PartitionSpec(
-        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
-        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
-    )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
-    assert {table_partition.partition_key.partition for table_partition in 
result} == {
-        Record(n_legs_identity=2, year_identity=2020),
-        Record(n_legs_identity=100, year_identity=2021),
-        Record(n_legs_identity=4, year_identity=2021),
-        Record(n_legs_identity=4, year_identity=2022),
-        Record(n_legs_identity=2, year_identity=2022),
-        Record(n_legs_identity=5, year_identity=2019),
-    }
-    assert (
-        pa.concat_tables([table_partition.arrow_table_partition for 
table_partition in result]).num_rows == arrow_table.num_rows
-    )
-
-
-def test_identity_partition_on_multi_columns() -> None:
-    import pyarrow as pa
-
-    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", 
pa.int64()), ("animal", pa.string())])
-    test_schema = Schema(
-        NestedField(field_id=1, name="born_year", field_type=StringType(), 
required=False),
-        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), 
required=True),
-        NestedField(field_id=3, name="animal", field_type=StringType(), 
required=False),
-        schema_id=1,
-    )
-    # 5 partitions, 6 unique row values, 12 rows
-    test_rows = [
-        (2021, 4, "Dog"),
-        (2022, 4, "Horse"),
-        (2022, 4, "Another Horse"),
-        (2021, 100, "Centipede"),
-        (None, 4, "Kirin"),
-        (2021, None, "Fish"),
-    ] * 2
-    expected = {Record(n_legs_identity=test_rows[i][1], 
year_identity=test_rows[i][0]) for i in range(len(test_rows))}
-    partition_spec = PartitionSpec(
-        PartitionField(source_id=2, field_id=1002, 
transform=IdentityTransform(), name="n_legs_identity"),
-        PartitionField(source_id=1, field_id=1001, 
transform=IdentityTransform(), name="year_identity"),
-    )
-    import random
-
-    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
-    for _ in range(1000):
-        random.shuffle(test_rows)
-        test_data = {
-            "born_year": [row[0] for row in test_rows],
-            "n_legs": [row[1] for row in test_rows],
-            "animal": [row[2] for row in test_rows],
-        }
-        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
-
-        result = _determine_partitions(partition_spec, test_schema, 
arrow_table)
-
-        assert {table_partition.partition_key.partition for table_partition in 
result} == expected
-        concatenated_arrow_table = 
pa.concat_tables([table_partition.arrow_table_partition for table_partition in 
result])
-        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
-        assert concatenated_arrow_table.sort_by([
-            ("born_year", "ascending"),
-            ("n_legs", "ascending"),
-            ("animal", "ascending"),
-        ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", 
"ascending"), ("animal", "ascending")])

Reply via email to