This is an automated email from the ASF dual-hosted git repository.
fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 4148edb5 Partitioned Append on Identity Transform (#555)
4148edb5 is described below
commit 4148edb5e28ae88024a55e0b112238e65b873957
Author: Adrian Qin <[email protected]>
AuthorDate: Fri Apr 5 03:52:16 2024 -0400
Partitioned Append on Identity Transform (#555)
* partitioned append on identity transform
* remove unnecessary fixture
* added null/empty table tests; fixed part of PR comments
* tests for unsupported transforms; unit tests for partition slicing
algorithm
* add a comprehensive partition unit test
* clean up
* move common fixtures utils to utils.py and conftest
* pull partitioned table fixtures into tests for more real-time feedback when
running tests
* fix linting
* license
* save changes for switching codespaces
* part of the comment fixes
* fix one type error
* add support for timetype
* small fix for type hint
---
pyiceberg/io/pyarrow.py | 4 +-
pyiceberg/manifest.py | 27 +-
pyiceberg/partitioning.py | 10 +-
pyiceberg/table/__init__.py | 174 +++++++++-
pyiceberg/typedef.py | 4 +
tests/conftest.py | 29 +-
tests/integration/test_partitioning_key.py | 2 +-
.../test_writes/test_partitioned_writes.py | 386 +++++++++++++++++++++
tests/integration/{ => test_writes}/test_writes.py | 99 +-----
tests/integration/test_writes/utils.py | 85 +++++
tests/table/test_init.py | 89 ++++-
11 files changed, 763 insertions(+), 146 deletions(-)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 738cd77b..06d03e21 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata,
tasks: Iterator[WriteT
)
def write_parquet(task: WriteTask) -> DataFile:
- file_path =
f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
+ file_path =
f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema,
**parquet_writer_kwargs) as writer:
@@ -1787,7 +1787,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata,
tasks: Iterator[WriteT
content=DataFileContent.DATA,
file_path=file_path,
file_format=FileFormat.PARQUET,
- partition=Record(),
+ partition=task.partition_key.partition if task.partition_key else
Record(),
file_size_in_bytes=len(fo),
# After this has been fixed:
# https://github.com/apache/iceberg-python/issues/271
diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py
index 5277eed9..3b8138b6 100644
--- a/pyiceberg/manifest.py
+++ b/pyiceberg/manifest.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import math
from abc import ABC, abstractmethod
from enum import Enum
-from functools import singledispatch
from types import TracebackType
from typing import (
Any,
@@ -41,8 +40,6 @@ from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
from pyiceberg.types import (
BinaryType,
BooleanType,
- DateType,
- IcebergType,
IntegerType,
ListType,
LongType,
@@ -51,9 +48,6 @@ from pyiceberg.types import (
PrimitiveType,
StringType,
StructType,
- TimestampType,
- TimestamptzType,
- TimeType,
)
UNASSIGNED_SEQ = -1
@@ -283,31 +277,12 @@ DATA_FILE_TYPE: Dict[int, StructType] = {
}
-@singledispatch
-def partition_field_to_data_file_partition_field(partition_field_type:
IcebergType) -> PrimitiveType:
- raise TypeError(f"Unsupported partition field type:
{partition_field_type}")
-
-
-@partition_field_to_data_file_partition_field.register(LongType)
-@partition_field_to_data_file_partition_field.register(DateType)
-@partition_field_to_data_file_partition_field.register(TimeType)
-@partition_field_to_data_file_partition_field.register(TimestampType)
-@partition_field_to_data_file_partition_field.register(TimestamptzType)
-def _(partition_field_type: PrimitiveType) -> IntegerType:
- return IntegerType()
-
-
-@partition_field_to_data_file_partition_field.register(PrimitiveType)
-def _(partition_field_type: PrimitiveType) -> PrimitiveType:
- return partition_field_type
-
-
def data_file_with_partition(partition_type: StructType, format_version:
TableVersion) -> StructType:
data_file_partition_type = StructType(*[
NestedField(
field_id=field.field_id,
name=field.name,
-
field_type=partition_field_to_data_file_partition_field(field.field_type),
+ field_type=field.field_type,
required=field.required,
)
for field in partition_type.fields
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index 16f15882..a3cf2553 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from datetime import date, datetime
+from datetime import date, datetime, time
from functools import cached_property, singledispatch
from typing import (
Any,
@@ -62,9 +62,10 @@ from pyiceberg.types import (
StructType,
TimestampType,
TimestamptzType,
+ TimeType,
UUIDType,
)
-from pyiceberg.utils.datetime import date_to_days, datetime_to_micros
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros,
time_to_micros
INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
@@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) ->
Optional[int]:
return date_to_days(value) if value is not None else None
+@_to_partition_representation.register(TimeType)
+def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
+ return time_to_micros(value) if value is not None else None
+
+
@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e183d827..2dbc32d8 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -16,13 +16,13 @@
# under the License.
from __future__ import annotations
-import datetime
import itertools
import uuid
import warnings
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
+from datetime import datetime
from enum import Enum
from functools import cached_property, singledispatch
from itertools import chain
@@ -79,6 +79,8 @@ from pyiceberg.partitioning import (
PARTITION_FIELD_ID_START,
UNPARTITIONED_PARTITION_SPEC,
PartitionField,
+ PartitionFieldValue,
+ PartitionKey,
PartitionSpec,
_PartitionNameGenerator,
_visit_partition_field,
@@ -373,8 +375,11 @@ class Transaction:
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")
- if len(self._table.spec().fields) > 0:
- raise ValueError("Cannot write to partitioned tables")
+ supported_transforms = {IdentityTransform}
+ if not all(type(field.transform) in supported_transforms for field in
self.table_metadata.spec().fields):
+ raise ValueError(
+ f"All transforms are not supported, expected:
{supported_transforms}, but get: {[str(field) for field in
self.table_metadata.spec().fields if field.transform not in
supported_transforms]}."
+ )
_check_schema_compatible(self._table.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
@@ -897,7 +902,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata:
TableMetadata, context: _Tabl
if update.ref_name == MAIN_BRANCH:
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
if "last_updated_ms" not in metadata_updates:
- metadata_updates["last_updated_ms"] =
datetime_to_millis(datetime.datetime.now().astimezone())
+ metadata_updates["last_updated_ms"] =
datetime_to_millis(datetime.now().astimezone())
metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
SnapshotLogEntry(
@@ -2646,16 +2651,23 @@ def _add_and_move_fields(
class WriteTask:
write_uuid: uuid.UUID
task_id: int
+ schema: Schema
record_batches: List[pa.RecordBatch]
sort_order_id: Optional[int] = None
-
- # Later to be extended with partition information
+ partition_key: Optional[PartitionKey] = None
def generate_data_file_filename(self, extension: str) -> str:
# Mimics the behavior in the Java API:
#
https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
return f"00000-{self.task_id}-{self.write_uuid}.{extension}"
+ def generate_data_file_path(self, extension: str) -> str:
+ if self.partition_key:
+ file_path =
f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
+ return file_path
+ else:
+ return self.generate_data_file_filename(extension)
+
@dataclass(frozen=True)
class AddFileTask:
@@ -2683,25 +2695,40 @@ def _dataframe_to_data_files(
"""
from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file
- if len([spec for spec in table_metadata.partition_specs if spec.spec_id !=
0]) > 0:
- raise ValueError("Cannot write to partitioned tables")
-
counter = itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()
-
- target_file_size = PropertyUtil.property_as_int(
+ target_file_size: int = PropertyUtil.property_as_int( # type: ignore #
The property is set with non-None value.
properties=table_metadata.properties,
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
- # This is an iter, so we don't have to materialize everything every time
- # This will be more relevant when we start doing partitioned writes
- yield from write_file(
- io=io,
- table_metadata=table_metadata,
- tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches
in bin_pack_arrow_table(df, target_file_size)]), # type: ignore
- )
+ if len(table_metadata.spec().fields) > 0:
+ partitions = _determine_partitions(spec=table_metadata.spec(),
schema=table_metadata.schema(), arrow_table=df)
+ yield from write_file(
+ io=io,
+ table_metadata=table_metadata,
+ tasks=iter([
+ WriteTask(
+ write_uuid=write_uuid,
+ task_id=next(counter),
+ record_batches=batches,
+ partition_key=partition.partition_key,
+ schema=table_metadata.schema(),
+ )
+ for partition in partitions
+ for batches in
bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
+ ]),
+ )
+ else:
+ yield from write_file(
+ io=io,
+ table_metadata=table_metadata,
+ tasks=iter([
+ WriteTask(write_uuid=write_uuid, task_id=next(counter),
record_batches=batches, schema=table_metadata.schema())
+ for batches in bin_pack_arrow_table(df, target_file_size)
+ ]),
+ )
def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths:
List[str], io: FileIO) -> Iterable[DataFile]:
@@ -3253,7 +3280,7 @@ class InspectTable:
additional_properties = None
snapshots.append({
- 'committed_at':
datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
+ 'committed_at':
datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
'snapshot_id': snapshot.snapshot_id,
'parent_id': snapshot.parent_snapshot_id,
'operation': str(operation),
@@ -3388,3 +3415,112 @@ class InspectTable:
entries,
schema=entries_schema,
)
+
+
+@dataclass(frozen=True)
+class TablePartition:
+ partition_key: PartitionKey
+ arrow_table_partition: pa.Table
+
+
+def _get_partition_sort_order(partition_columns: list[str], reverse: bool =
False) -> dict[str, Any]:
+ order = 'ascending' if not reverse else 'descending'
+ null_placement = 'at_start' if reverse else 'at_end'
+ return {'sort_keys': [(column_name, order) for column_name in
partition_columns], 'null_placement': null_placement}
+
+
+def group_by_partition_scheme(arrow_table: pa.Table, partition_columns:
list[str]) -> pa.Table:
+ """Given a table, sort it by current partition scheme."""
+ # only works for identity for now
+ sort_options = _get_partition_sort_order(partition_columns, reverse=False)
+ sorted_arrow_table =
arrow_table.sort_by(sorting=sort_options['sort_keys'],
null_placement=sort_options['null_placement'])
+ return sorted_arrow_table
+
+
+def get_partition_columns(
+ spec: PartitionSpec,
+ schema: Schema,
+) -> list[str]:
+ partition_cols = []
+ for partition_field in spec.fields:
+ column_name = schema.find_column_name(partition_field.source_id)
+ if not column_name:
+ raise ValueError(f"{partition_field=} could not be found in
{schema}.")
+ partition_cols.append(column_name)
+ return partition_cols
+
+
+def _get_table_partitions(
+ arrow_table: pa.Table,
+ partition_spec: PartitionSpec,
+ schema: Schema,
+ slice_instructions: list[dict[str, Any]],
+) -> list[TablePartition]:
+ sorted_slice_instructions = sorted(slice_instructions, key=lambda x:
x['offset'])
+
+ partition_fields = partition_spec.fields
+
+ offsets = [inst["offset"] for inst in sorted_slice_instructions]
+ projected_and_filtered = {
+ partition_field.source_id:
arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+ .take(offsets)
+ .to_pylist()
+ for partition_field in partition_fields
+ }
+
+ table_partitions = []
+ for idx, inst in enumerate(sorted_slice_instructions):
+ partition_slice = arrow_table.slice(**inst)
+ fieldvalues = [
+ PartitionFieldValue(partition_field,
projected_and_filtered[partition_field.source_id][idx])
+ for partition_field in partition_fields
+ ]
+ partition_key = PartitionKey(raw_partition_field_values=fieldvalues,
partition_spec=partition_spec, schema=schema)
+ table_partitions.append(TablePartition(partition_key=partition_key,
arrow_table_partition=partition_slice))
+ return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table:
pa.Table) -> List[TablePartition]:
+ """Based on the iceberg table partition spec, slice the arrow table into
partitions with their keys.
+
+ Example:
+ Input:
+ An arrow table with partition key of ['n_legs', 'year'] and with data of
+ {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+ 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+ 'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse",
"Horse","Brittle stars", "Centipede"]}.
+ The algorithm:
+ Firstly we group the rows into partitions by sorting with sort order
[('n_legs', 'ascending'), ('year', 'ascending')]
+ and null_placement of "at_end".
+ This gives the same table as raw input.
+ Then we sort_indices using reverse order of [('n_legs', 'descending'),
('year', 'descending')]
+ and null_placement : "at_start".
+ This gives:
+ [8, 7, 4, 5, 6, 3, 1, 2, 0]
+ Based on this we get partition groups of indices:
+ [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4,
'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2},
{'offset': 0, 'length': 1}]
+ We then retrieve the partition keys by offsets.
+ And slice the arrow table by offsets and lengths of each partition.
+ """
+ import pyarrow as pa
+
+ partition_columns = get_partition_columns(spec=spec, schema=schema)
+ arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
+
+ reversing_sort_order_options =
_get_partition_sort_order(partition_columns, reverse=True)
+ reversed_indices = pa.compute.sort_indices(arrow_table,
**reversing_sort_order_options).to_pylist()
+
+ slice_instructions: list[dict[str, Any]] = []
+ last = len(reversed_indices)
+ reversed_indices_size = len(reversed_indices)
+ ptr = 0
+ while ptr < reversed_indices_size:
+ group_size = last - reversed_indices[ptr]
+ offset = reversed_indices[ptr]
+ slice_instructions.append({"offset": offset, "length": group_size})
+ last = reversed_indices[ptr]
+ ptr = ptr + group_size
+
+ table_partitions: list[TablePartition] =
_get_table_partitions(arrow_table, spec, schema, slice_instructions)
+
+ return table_partitions
diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py
index 4bed386c..6ccf9526 100644
--- a/pyiceberg/typedef.py
+++ b/pyiceberg/typedef.py
@@ -202,5 +202,9 @@ class Record(StructProtocol):
"""Return values of all the fields of the Record class except those
specified in skip_fields."""
return [self.__getattribute__(v) if hasattr(self, v) else None for v
in self._position_to_field_name]
+ def __hash__(self) -> int:
+ """Return hash value of the Record class."""
+ return hash(str(self))
+
TableVersion: TypeAlias = Literal[1, 2]
diff --git a/tests/conftest.py b/tests/conftest.py
index aa09517b..4a820fed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,7 +30,7 @@ import re
import socket
import string
import uuid
-from datetime import date, datetime
+from datetime import date, datetime, timezone
from pathlib import Path
from random import choice
from tempfile import TemporaryDirectory
@@ -1999,8 +1999,13 @@ TEST_DATA_WITH_NULL = {
'long': [1, None, 9],
'float': [0.0, None, 0.9],
'double': [0.0, None, 0.9],
+ # 'time': [1_000_000, None, 3_000_000], # Example times: 1s, none, and 3s
past midnight #Spark does not support time fields
'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1,
19, 25, 00)],
- 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3,
1, 19, 25, 00)],
+ 'timestamptz': [
+ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
+ None,
+ datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
+ ],
'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
# Not supported by Spark
# 'time': [time(1, 22, 0), None, time(19, 25, 0)],
@@ -2027,6 +2032,8 @@ def pa_schema() -> "pa.Schema":
("long", pa.int64()),
("float", pa.float32()),
("double", pa.float64()),
+ # Not supported by Spark
+ # ("time", pa.time64('us')),
("timestamp", pa.timestamp(unit="us")),
("timestamptz", pa.timestamp(unit="us", tz="UTC")),
("date", pa.date32()),
@@ -2041,7 +2048,23 @@ def pa_schema() -> "pa.Schema":
@pytest.fixture(scope="session")
def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table":
+ """Pyarrow table with all kinds of columns."""
import pyarrow as pa
- """Pyarrow table with all kinds of columns."""
return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)
+
+
[email protected](scope="session")
+def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table":
+ """Pyarrow table without data."""
+ import pyarrow as pa
+
+ return pa.Table.from_pylist([], schema=pa_schema)
+
+
[email protected](scope="session")
+def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
+ """Pyarrow table with only null values."""
+ import pyarrow as pa
+
+ return pa.Table.from_pylist([{}, {}], schema=pa_schema)
diff --git a/tests/integration/test_partitioning_key.py
b/tests/integration/test_partitioning_key.py
index 12056bac..d89ecaf2 100644
--- a/tests/integration/test_partitioning_key.py
+++ b/tests/integration/test_partitioning_key.py
@@ -749,7 +749,7 @@ def test_partition_key(
# key.to_path() generates the hive partitioning part of the to-write
parquet file path
assert key.to_path() == expected_hive_partition_path_slice
- # Justify expected values are not made up but conform to spark behaviors
+ # Justify expected values are not made up but conforming to spark behaviors
if spark_create_table_sql_for_justification is not None and
spark_data_insert_sql_for_justification is not None:
try:
spark.sql(f"drop table {identifier}")
diff --git a/tests/integration/test_writes/test_partitioned_writes.py
b/tests/integration/test_writes/test_partitioned_writes.py
new file mode 100644
index 00000000..d84b9745
--- /dev/null
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+
+import pyarrow as pa
+import pytest
+from pyspark.sql import SparkSession
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.transforms import (
+ BucketTransform,
+ DayTransform,
+ HourTransform,
+ IdentityTransform,
+ MonthTransform,
+ TruncateTransform,
+ YearTransform,
+)
+from tests.conftest import TEST_DATA_WITH_NULL
+from utils import TABLE_SCHEMA, _create_table
+
+
[email protected]
[email protected](
+ "part_col", ['int', 'bool', 'string', "string_long", "long", "float",
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_null_partitioned(
+ session_catalog: Catalog, spark: SparkSession, arrow_table_with_null:
pa.Table, part_col: str, format_version: int
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=IdentityTransform(), name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_with_null],
+ partition_spec=partition_spec,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+ for col in TEST_DATA_WITH_NULL.keys():
+ assert df.where(f"{col} is not null").count() == 2, f"Expected 2
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row
for {col} is null"
+
+
[email protected]
[email protected](
+ "part_col", ['int', 'bool', 'string', "string_long", "long", "float",
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_without_data_partitioned(
+ session_catalog: Catalog, spark: SparkSession, arrow_table_without_data:
pa.Table, part_col: str, format_version: int
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_without_data_partitioned_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=IdentityTransform(), name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_without_data],
+ partition_spec=partition_spec,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ for col in TEST_DATA_WITH_NULL.keys():
+ assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for
{col}"
+ assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row
for {col}"
+
+
[email protected]
[email protected](
+ "part_col", ['int', 'bool', 'string', "string_long", "long", "float",
"double", "date", 'timestamp', 'timestamptz', 'binary']
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_only_nulls_partitioned(
+ session_catalog: Catalog, spark: SparkSession,
arrow_table_with_only_nulls: pa.Table, part_col: str, format_version: int
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_with_only_nulls_partitioned_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=IdentityTransform(), name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[arrow_table_with_only_nulls],
+ partition_spec=partition_spec,
+ )
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ for col in TEST_DATA_WITH_NULL.keys():
+ assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for
{col}"
+ assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows
for {col}"
+
+
[email protected]
[email protected](
+ "part_col", ['int', 'bool', 'string', "string_long", "long", "float",
"double", "date", "timestamptz", "timestamp", "binary"]
+)
[email protected]("format_version", [1, 2])
+def test_query_filter_appended_null_partitioned(
+ session_catalog: Catalog, spark: SparkSession, arrow_table_with_null:
pa.Table, part_col: str, format_version: int
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=IdentityTransform(), name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": str(format_version)},
+ data=[],
+ partition_spec=partition_spec,
+ )
+ # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with
lines[A,B,C,A,B,C]
+ tbl.append(arrow_table_with_null)
+ tbl.append(pa.concat_tables([arrow_table_with_null,
arrow_table_with_null]))
+
+ # Then
+ assert tbl.format_version == format_version, f"Expected v{format_version},
got: v{tbl.format_version}"
+ df = spark.table(identifier)
+ for col in TEST_DATA_WITH_NULL.keys():
+ df = spark.table(identifier)
+ assert df.where(f"{col} is not null").count() == 6, f"Expected 6
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows
for {col}"
+ # expecting 6 files: first append with [A], [B], [C], second append with
[A, A], [B, B], [C, C]
+ rows = spark.sql(f"select partition from {identifier}.files").collect()
+ assert len(rows) == 6
+
+
[email protected]
[email protected](
+ "part_col", ['int', 'bool', 'string', "string_long", "long", "float",
"double", "date", "timestamptz", "timestamp", "binary"]
+)
+def test_query_filter_v1_v2_append_null(
+ session_catalog: Catalog, spark: SparkSession, arrow_table_with_null:
pa.Table, part_col: str
+) -> None:
+ # Given
+ identifier =
f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{part_col}"
+ nested_field = TABLE_SCHEMA.find_field(part_col)
+ partition_spec = PartitionSpec(
+ PartitionField(source_id=nested_field.field_id, field_id=1001,
transform=IdentityTransform(), name=part_col)
+ )
+
+ # When
+ tbl = _create_table(
+ session_catalog=session_catalog,
+ identifier=identifier,
+ properties={"format-version": "1"},
+ data=[],
+ partition_spec=partition_spec,
+ )
+ tbl.append(arrow_table_with_null)
+
+ # Then
+ assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"
+
+ # When
+ with tbl.transaction() as tx:
+ tx.upgrade_table_version(format_version=2)
+
+ tbl.append(arrow_table_with_null)
+
+ # Then
+ assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"
+ for col in TEST_DATA_WITH_NULL.keys(): # type: ignore
+ df = spark.table(identifier)
+ assert df.where(f"{col} is not null").count() == 4, f"Expected 4
non-null rows for {col}"
+ assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows
for {col}"
+
+
[email protected]
+def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog,
arrow_table_with_null: pa.Table) -> None:
+ identifier = "default.arrow_table_summaries"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+ tbl = session_catalog.create_table(
+ identifier=identifier,
+ schema=TABLE_SCHEMA,
+ partition_spec=PartitionSpec(PartitionField(source_id=4,
field_id=1001, transform=IdentityTransform(), name="int")),
+ properties={'format-version': '2'},
+ )
+
+ tbl.append(arrow_table_with_null)
+ tbl.append(arrow_table_with_null)
+
+ rows = spark.sql(
+ f"""
+ SELECT operation, summary
+ FROM {identifier}.snapshots
+ ORDER BY committed_at ASC
+ """
+ ).collect()
+
+ operations = [row.operation for row in rows]
+ assert operations == ['append', 'append']
+
+ summaries = [row.summary for row in rows]
+ assert summaries[0] == {
+ 'changed-partition-count': '3',
+ 'added-data-files': '3',
+ 'added-files-size': '15029',
+ 'added-records': '3',
+ 'total-data-files': '3',
+ 'total-delete-files': '0',
+ 'total-equality-deletes': '0',
+ 'total-files-size': '15029',
+ 'total-position-deletes': '0',
+ 'total-records': '3',
+ }
+
+ assert summaries[1] == {
+ 'changed-partition-count': '3',
+ 'added-data-files': '3',
+ 'added-files-size': '15029',
+ 'added-records': '3',
+ 'total-data-files': '6',
+ 'total-delete-files': '0',
+ 'total-equality-deletes': '0',
+ 'total-files-size': '30058',
+ 'total-position-deletes': '0',
+ 'total-records': '6',
+ }
+
+
[email protected]
+def test_data_files_with_table_partitioned_with_null(
+ spark: SparkSession, session_catalog: Catalog, arrow_table_with_null:
pa.Table
+) -> None:
+ identifier = "default.arrow_data_files"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+ tbl = session_catalog.create_table(
+ identifier=identifier,
+ schema=TABLE_SCHEMA,
+ partition_spec=PartitionSpec(PartitionField(source_id=4,
field_id=1001, transform=IdentityTransform(), name="int")),
+ properties={'format-version': '1'},
+ )
+
+ tbl.append(arrow_table_with_null)
+ tbl.append(arrow_table_with_null)
+
+ # added_data_files_count, existing_data_files_count,
deleted_data_files_count
+ rows = spark.sql(
+ f"""
+ SELECT added_data_files_count, existing_data_files_count,
deleted_data_files_count
+ FROM {identifier}.all_manifests
+ """
+ ).collect()
+
+ assert [row.added_data_files_count for row in rows] == [3, 3, 3]
+ assert [row.existing_data_files_count for row in rows] == [
+ 0,
+ 0,
+ 0,
+ ]
+ assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]
+
+
[email protected]
+def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) ->
None:
+ identifier = "default.arrow_data_files"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+
+ tbl = session_catalog.create_table(
+ identifier=identifier,
+ schema=TABLE_SCHEMA,
+ partition_spec=PartitionSpec(PartitionField(source_id=4,
field_id=1001, transform=IdentityTransform(), name="int")),
+ properties={'format-version': '1'},
+ )
+
+ with pytest.raises(ValueError, match="Expected PyArrow table, got: not a
df"):
+ tbl.append("not a df")
+
+
[email protected]
[email protected](
+ "spec",
+ [
+ # mixed with non-identity is not supported
+ (
+ PartitionSpec(
+ PartitionField(source_id=4, field_id=1001,
transform=BucketTransform(2), name="int_bucket"),
+ PartitionField(source_id=1, field_id=1002,
transform=IdentityTransform(), name="bool"),
+ )
+ ),
+ # none of non-identity is supported
+ (PartitionSpec(PartitionField(source_id=4, field_id=1001,
transform=BucketTransform(2), name="int_bucket"))),
+ (PartitionSpec(PartitionField(source_id=5, field_id=1001,
transform=BucketTransform(2), name="long_bucket"))),
+ (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=BucketTransform(2), name="date_bucket"))),
+ (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=BucketTransform(2), name="timestamp_bucket"))),
+ (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=BucketTransform(2), name="timestamptz_bucket"))),
+ (PartitionSpec(PartitionField(source_id=2, field_id=1001,
transform=BucketTransform(2), name="string_bucket"))),
+ (PartitionSpec(PartitionField(source_id=12, field_id=1001,
transform=BucketTransform(2), name="fixed_bucket"))),
+ (PartitionSpec(PartitionField(source_id=11, field_id=1001,
transform=BucketTransform(2), name="binary_bucket"))),
+ (PartitionSpec(PartitionField(source_id=4, field_id=1001,
transform=TruncateTransform(2), name="int_trunc"))),
+ (PartitionSpec(PartitionField(source_id=5, field_id=1001,
transform=TruncateTransform(2), name="long_trunc"))),
+ (PartitionSpec(PartitionField(source_id=2, field_id=1001,
transform=TruncateTransform(2), name="string_trunc"))),
+ (PartitionSpec(PartitionField(source_id=11, field_id=1001,
transform=TruncateTransform(2), name="binary_trunc"))),
+ (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=YearTransform(), name="timestamp_year"))),
+ (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=YearTransform(), name="timestamptz_year"))),
+ (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=YearTransform(), name="date_year"))),
+ (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=MonthTransform(), name="timestamp_month"))),
+ (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=MonthTransform(), name="timestamptz_month"))),
+ (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=MonthTransform(), name="date_month"))),
+ (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=DayTransform(), name="timestamp_day"))),
+ (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=DayTransform(), name="timestamptz_day"))),
+ (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=DayTransform(), name="date_day"))),
+ (PartitionSpec(PartitionField(source_id=8, field_id=1001,
transform=HourTransform(), name="timestamp_hour"))),
+ (PartitionSpec(PartitionField(source_id=9, field_id=1001,
transform=HourTransform(), name="timestamptz_hour"))),
+ (PartitionSpec(PartitionField(source_id=10, field_id=1001,
transform=HourTransform(), name="date_hour"))),
+ ],
+)
+def test_unsupported_transform(
+ spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog,
arrow_table_with_null: pa.Table
+) -> None:
+ identifier = "default.unsupported_transform"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+
+ tbl = session_catalog.create_table(
+ identifier=identifier,
+ schema=TABLE_SCHEMA,
+ partition_spec=spec,
+ properties={'format-version': '1'},
+ )
+
+ with pytest.raises(ValueError, match="All transforms are not supported.*"):
+ tbl.append(arrow_table_with_null)
diff --git a/tests/integration/test_writes.py
b/tests/integration/test_writes/test_writes.py
similarity index 88%
rename from tests/integration/test_writes.py
rename to tests/integration/test_writes/test_writes.py
index e950fb43..62d3bb11 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -18,10 +18,9 @@
import math
import os
import time
-import uuid
from datetime import date, datetime
from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict
from urllib.parse import urlparse
import pyarrow as pa
@@ -36,93 +35,9 @@ from pytest_mock.plugin import MockerFixture
from pyiceberg.catalog import Catalog
from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import NoSuchTableError
-from pyiceberg.schema import Schema
-from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files
-from pyiceberg.typedef import Properties
-from pyiceberg.types import (
- BinaryType,
- BooleanType,
- DateType,
- DoubleType,
- FixedType,
- FloatType,
- IntegerType,
- LongType,
- NestedField,
- StringType,
- TimestampType,
- TimestamptzType,
-)
-
-TEST_DATA_WITH_NULL = {
- 'bool': [False, None, True],
- 'string': ['a', None, 'z'],
- # Go over the 16 bytes to kick in truncation
- 'string_long': ['a' * 22, None, 'z' * 22],
- 'int': [1, None, 9],
- 'long': [1, None, 9],
- 'float': [0.0, None, 0.9],
- 'double': [0.0, None, 0.9],
- 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1,
19, 25, 00)],
- 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3,
1, 19, 25, 00)],
- 'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
- # Not supported by Spark
- # 'time': [time(1, 22, 0), None, time(19, 25, 0)],
- # Not natively supported by Arrow
- # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None,
uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
- 'binary': [b'\01', None, b'\22'],
- 'fixed': [
- uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
- None,
- uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
- ],
-}
-
-TABLE_SCHEMA = Schema(
- NestedField(field_id=1, name="bool", field_type=BooleanType(),
required=False),
- NestedField(field_id=2, name="string", field_type=StringType(),
required=False),
- NestedField(field_id=3, name="string_long", field_type=StringType(),
required=False),
- NestedField(field_id=4, name="int", field_type=IntegerType(),
required=False),
- NestedField(field_id=5, name="long", field_type=LongType(),
required=False),
- NestedField(field_id=6, name="float", field_type=FloatType(),
required=False),
- NestedField(field_id=7, name="double", field_type=DoubleType(),
required=False),
- NestedField(field_id=8, name="timestamp", field_type=TimestampType(),
required=False),
- NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(),
required=False),
- NestedField(field_id=10, name="date", field_type=DateType(),
required=False),
- # NestedField(field_id=11, name="time", field_type=TimeType(),
required=False),
- # NestedField(field_id=12, name="uuid", field_type=UuidType(),
required=False),
- NestedField(field_id=12, name="binary", field_type=BinaryType(),
required=False),
- NestedField(field_id=13, name="fixed", field_type=FixedType(16),
required=False),
-)
-
-
[email protected](scope="session")
-def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table:
- """PyArrow table with all kinds of columns"""
- return pa.Table.from_pylist([], schema=pa_schema)
-
-
[email protected](scope="session")
-def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table:
- """PyArrow table with all kinds of columns"""
- return pa.Table.from_pylist([{}, {}], schema=pa_schema)
-
-
-def _create_table(
- session_catalog: Catalog, identifier: str, properties: Properties, data:
Optional[List[pa.Table]] = None
-) -> Table:
- try:
- session_catalog.drop_table(identifier=identifier)
- except NoSuchTableError:
- pass
-
- tbl = session_catalog.create_table(identifier=identifier,
schema=TABLE_SCHEMA, properties=properties)
-
- if data:
- for d in data:
- tbl.append(d)
-
- return tbl
+from pyiceberg.table import TableProperties, _dataframe_to_data_files
+from tests.conftest import TEST_DATA_WITH_NULL
+from utils import _create_table
@pytest.fixture(scope="session", autouse=True)
@@ -219,7 +134,7 @@ def test_query_filter_without_data(spark: SparkSession,
col: str, format_version
identifier = f"default.arrow_table_v{format_version}_without_data"
df = spark.table(identifier)
assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}"
- assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for
{col}"
+ assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for
{col}"
@pytest.mark.integration
@@ -228,8 +143,8 @@ def test_query_filter_without_data(spark: SparkSession,
col: str, format_version
def test_query_filter_only_nulls(spark: SparkSession, col: str,
format_version: int) -> None:
identifier = f"default.arrow_table_v{format_version}_with_only_nulls"
df = spark.table(identifier)
- assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}"
- assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for
{col}"
+ assert df.where(f"{col} is null").count() == 2, f"Expected 2 rows for
{col}"
+ assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for
{col}"
@pytest.mark.integration
diff --git a/tests/integration/test_writes/utils.py
b/tests/integration/test_writes/utils.py
new file mode 100644
index 00000000..792e2518
--- /dev/null
+++ b/tests/integration/test_writes/utils.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+from typing import List, Optional
+
+import pyarrow as pa
+
+from pyiceberg.catalog import Catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.typedef import Properties
+from pyiceberg.types import (
+ BinaryType,
+ BooleanType,
+ DateType,
+ DoubleType,
+ FixedType,
+ FloatType,
+ IntegerType,
+ LongType,
+ NestedField,
+ StringType,
+ TimestampType,
+ TimestamptzType,
+)
+
+# Schema shared by the write integration tests: every column is optional and field ids run 1-12.
+TABLE_SCHEMA = Schema(
+    NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False),
+    NestedField(field_id=2, name="string", field_type=StringType(), required=False),
+    NestedField(field_id=3, name="string_long", field_type=StringType(), required=False),
+    NestedField(field_id=4, name="int", field_type=IntegerType(), required=False),
+    NestedField(field_id=5, name="long", field_type=LongType(), required=False),
+    NestedField(field_id=6, name="float", field_type=FloatType(), required=False),
+    NestedField(field_id=7, name="double", field_type=DoubleType(), required=False),
+    NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False),
+    NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False),
+    NestedField(field_id=10, name="date", field_type=DateType(), required=False),
+    # A "time" (TimeType) column is intentionally absent: Spark does not support time columns.
+    # A "uuid" (UuidType) column is intentionally absent: it is not natively supported by Arrow.
+    NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False),
+    NestedField(field_id=12, name="fixed", field_type=FixedType(16), required=False),
+)
+
+
+def _create_table(
+    session_catalog: Catalog,
+    identifier: str,
+    properties: Properties,
+    data: Optional[List[pa.Table]] = None,
+    partition_spec: Optional[PartitionSpec] = None,
+) -> Table:
+    """Drop ``identifier`` if it exists, recreate it with ``TABLE_SCHEMA`` and append ``data``.
+
+    Args:
+        session_catalog: Catalog in which the table is (re)created.
+        identifier: Table identifier, e.g. ``"default.my_table"``.
+        properties: Table properties forwarded to ``create_table``.
+        data: Optional Arrow tables, appended one at a time in list order.
+        partition_spec: Optional partition spec. When omitted, no spec is passed to
+            ``create_table``, so the catalog's default applies.
+
+    Returns:
+        The newly created table.
+    """
+    try:
+        session_catalog.drop_table(identifier=identifier)
+    except NoSuchTableError:
+        pass  # The table did not exist yet; there is nothing to drop.
+
+    # Compare against None explicitly rather than relying on the truthiness of a spec object.
+    if partition_spec is not None:
+        tbl = session_catalog.create_table(
+            identifier=identifier, schema=TABLE_SCHEMA, properties=properties, partition_spec=partition_spec
+        )
+    else:
+        tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties)
+
+    for d in data or []:
+        tbl.append(d)
+
+    return tbl
diff --git a/tests/table/test_init.py b/tests/table/test_init.py
index f1191295..2bc78f31 100644
--- a/tests/table/test_init.py
+++ b/tests/table/test_init.py
@@ -64,6 +64,7 @@ from pyiceberg.table import (
UpdateSchema,
_apply_table_update,
_check_schema_compatible,
+ _determine_partitions,
_match_deletes_to_data_file,
_TableMetadataUpdateContext,
update_table_metadata,
@@ -82,7 +83,11 @@ from pyiceberg.table.sorting import (
SortField,
SortOrder,
)
-from pyiceberg.transforms import BucketTransform, IdentityTransform
+from pyiceberg.transforms import (
+ BucketTransform,
+ IdentityTransform,
+)
+from pyiceberg.typedef import Record
from pyiceberg.types import (
BinaryType,
BooleanType,
@@ -1139,3 +1144,85 @@ def test_serialize_commit_table_request() -> None:
deserialized_request =
CommitTableRequest.model_validate_json(request.model_dump_json())
assert request == deserialized_request
+
+
+def test_partition_for_demo() -> None:
+    """Identity-partition a small table on (n_legs, year) and check the keys and row coverage.
+
+    The nine rows contain six distinct (n_legs, year) pairs, so six partition keys are expected.
+    Together, the partitions must hold every input row.
+    """
+    import pyarrow as pa
+
+    test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    # Iceberg field types mirror the Arrow columns: "year" holds integers, not strings.
+    test_schema = Schema(
+        NestedField(field_id=1, name='year', field_type=IntegerType(), required=False),
+        NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name='animal', field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    test_data = {
+        'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
+        'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+        'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
+    }
+    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    assert {table_partition.partition_key.partition for table_partition in result} == {
+        Record(n_legs_identity=2, year_identity=2020),
+        Record(n_legs_identity=100, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2021),
+        Record(n_legs_identity=4, year_identity=2022),
+        Record(n_legs_identity=2, year_identity=2022),
+        Record(n_legs_identity=5, year_identity=2019),
+    }
+    # No row may be dropped or duplicated when the table is sliced into partitions.
+    assert (
+        pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows
+    )
+
+
+def test_identity_partition_on_multi_columns() -> None:
+    """Identity partitioning on two columns must not depend on the input row order.
+
+    The same 12 rows are shuffled repeatedly. Every shuffle must yield the same set of partition
+    keys, including keys with null values. Together, the partitions must contain exactly the
+    input rows.
+    """
+    import random
+
+    import pyarrow as pa
+
+    test_pa_schema = pa.schema([('born_year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
+    # Iceberg types mirror the Arrow columns. Both partition source columns contain nulls below,
+    # so neither is declared required.
+    test_schema = Schema(
+        NestedField(field_id=1, name='born_year', field_type=IntegerType(), required=False),
+        NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=False),
+        NestedField(field_id=3, name='animal', field_type=StringType(), required=False),
+        schema_id=1,
+    )
+    # 5 partitions, 6 unique row values, 12 rows
+    test_rows = [
+        (2021, 4, "Dog"),
+        (2022, 4, "Horse"),
+        (2022, 4, "Another Horse"),
+        (2021, 100, "Centipede"),
+        (None, 4, "Kirin"),
+        (2021, None, "Fish"),
+    ] * 2
+    expected = {Record(n_legs_identity=n_legs, year_identity=born_year) for born_year, n_legs, _ in test_rows}
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
+        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
+    )
+    sort_keys = [('born_year', 'ascending'), ('n_legs', 'ascending'), ('animal', 'ascending')]
+    # Fixed seed, so a failing permutation can be reproduced instead of showing up as a flaky test.
+    rng = random.Random(42)
+
+    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
+    for _ in range(1000):
+        rng.shuffle(test_rows)
+        born_years, legs, animals = (list(column) for column in zip(*test_rows))
+        arrow_table = pa.Table.from_pydict(
+            {'born_year': born_years, 'n_legs': legs, 'animal': animals}, schema=test_pa_schema
+        )
+
+        result = _determine_partitions(partition_spec, test_schema, arrow_table)
+
+        assert {table_partition.partition_key.partition for table_partition in result} == expected
+        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
+        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
+        # Row order differs between input and partitions, so compare both after sorting.
+        assert concatenated_arrow_table.sort_by(sort_keys) == arrow_table.sort_by(sort_keys)