This is an automated email from the ASF dual-hosted git repository.

fokko pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push: new 33c8931c Feature: Write to branches (#941) 33c8931c is described below commit 33c8931c221e952e57bf829abf9b1d88c17baec8 Author: vinjai <vinayakjaiswa...@gmail.com> AuthorDate: Thu Jul 3 04:24:18 2025 +0530 Feature: Write to branches (#941) Fixes: https://github.com/apache/iceberg-python/issues/306 --------- Co-authored-by: Kevin Liu <kevin.jq....@gmail.com> --- pyiceberg/cli/console.py | 6 +- pyiceberg/table/__init__.py | 97 +++++++++++---- pyiceberg/table/update/__init__.py | 8 +- pyiceberg/table/update/snapshot.py | 178 +++++++++++++++++---------- pyiceberg/utils/concurrent.py | 10 +- tests/integration/test_deletes.py | 29 +++++ tests/integration/test_writes/test_writes.py | 160 +++++++++++++++++++++++- tests/table/test_init.py | 32 +++-- 8 files changed, 409 insertions(+), 111 deletions(-) diff --git a/pyiceberg/cli/console.py b/pyiceberg/cli/console.py index 6be4df12..3fbd9c9f 100644 --- a/pyiceberg/cli/console.py +++ b/pyiceberg/cli/console.py @@ -33,7 +33,7 @@ from pyiceberg.catalog import URI, Catalog, load_catalog from pyiceberg.cli.output import ConsoleOutput, JsonOutput, Output from pyiceberg.exceptions import NoSuchNamespaceError, NoSuchPropertyException, NoSuchTableError from pyiceberg.table import TableProperties -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import SnapshotRef, SnapshotRefType from pyiceberg.utils.properties import property_as_int @@ -417,7 +417,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None: refs = table.refs() if type: type = type.lower() - if type not in {"branch", "tag"}: + if type not in {SnapshotRefType.BRANCH, SnapshotRefType.TAG}: raise ValueError(f"Type must be either branch or tag, got: {type}") relevant_refs = [ @@ -431,7 +431,7 @@ def list_refs(ctx: Context, identifier: str, type: str, verbose: bool) -> None: def _retention_properties(ref: SnapshotRef, table_properties: Dict[str, str]) -> Dict[str, str]: retention_properties = {} - if ref.snapshot_ref_type == "branch": + if ref.snapshot_ref_type == SnapshotRefType.BRANCH: default_min_snapshots_to_keep = property_as_int( table_properties, TableProperties.MIN_SNAPSHOTS_TO_KEEP, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2eebd0e4..07602c9e 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -87,7 +87,7 @@ from pyiceberg.table.metadata import ( from pyiceberg.table.name_mapping import ( NameMapping, ) -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef from pyiceberg.table.snapshots import ( Snapshot, SnapshotLogEntry, @@ -397,7 +397,7 @@ class Transaction: expr = Or(expr, match_partition_expression) return expr - def _append_snapshot_producer(self, snapshot_properties: Dict[str, str]) -> _FastAppendFiles: + def _append_snapshot_producer(self, snapshot_properties: Dict[str, str], branch: Optional[str]) -> _FastAppendFiles: """Determine the append type based on table properties. 
Args: @@ -410,7 +410,7 @@ class Transaction: TableProperties.MANIFEST_MERGE_ENABLED, TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT, ) - update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties) + update_snapshot = self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch) return update_snapshot.merge_append() if manifest_merge_enabled else update_snapshot.fast_append() def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: @@ -430,13 +430,16 @@ class Transaction: name_mapping=self.table_metadata.name_mapping(), ) - def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot: + def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot: """Create a new UpdateSnapshot to produce a new snapshot for the table. Returns: A new UpdateSnapshot """ - return UpdateSnapshot(self, io=self._table.io, snapshot_properties=snapshot_properties) + if branch is None: + branch = MAIN_BRANCH + + return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties) def update_statistics(self) -> UpdateStatistics: """ @@ -447,13 +450,14 @@ class Transaction: """ return UpdateStatistics(transaction=self) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None: """ Shorthand API for appending a PyArrow table to a table transaction. Args: df: The Arrow dataframe that will be appended to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the append operation """ try: import pyarrow as pa @@ -476,7 +480,7 @@ class Transaction: self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = list( @@ -487,7 +491,9 @@ class Transaction: for data_file in data_files: append_files.append_data_file(data_file) - def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def dynamic_partition_overwrite( + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None + ) -> None: """ Shorthand for overwriting existing partitions with a PyArrow table. 
@@ -498,6 +504,7 @@ class Transaction: Args: df: The Arrow dataframe that will be used to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the dynamic partition overwrite operation """ try: import pyarrow as pa @@ -536,9 +543,9 @@ class Transaction: partitions_to_overwrite = {data_file.partition for data_file in data_files} delete_filter = self._build_partition_predicate(partition_records=partitions_to_overwrite) - self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties) + self.delete(delete_filter=delete_filter, snapshot_properties=snapshot_properties, branch=branch) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: append_files.commit_uuid = append_snapshot_commit_uuid for data_file in data_files: append_files.append_data_file(data_file) @@ -549,6 +556,7 @@ class Transaction: overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> None: """ Shorthand for adding a table overwrite with a PyArrow table to the transaction. @@ -563,8 +571,9 @@ class Transaction: df: The Arrow dataframe that will be used to overwrite the table overwrite_filter: ALWAYS_TRUE when you overwrite all the data, or a boolean expression in case of a partial overwrite - case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive snapshot_properties: Custom properties to be added to the snapshot summary + case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive + branch: Branch Reference to run the overwrite operation """ try: import pyarrow as pa @@ -589,9 +598,14 @@ class Transaction: if overwrite_filter != AlwaysFalse(): # Only delete when the filter is != AlwaysFalse - self.delete(delete_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties) + self.delete( + delete_filter=overwrite_filter, + case_sensitive=case_sensitive, + snapshot_properties=snapshot_properties, + branch=branch, + ) - with self._append_snapshot_producer(snapshot_properties) as append_files: + with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files: # skip writing data files if the dataframe is empty if df.shape[0] > 0: data_files = _dataframe_to_data_files( @@ -605,6 +619,7 @@ class Transaction: delete_filter: Union[str, BooleanExpression], snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> None: """ Shorthand for deleting record from a table. 
@@ -618,6 +633,7 @@ class Transaction: delete_filter: A boolean expression to delete rows from a table snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive + branch: Branch Reference to run the delete operation """ from pyiceberg.io.pyarrow import ( ArrowScan, @@ -634,7 +650,7 @@ class Transaction: if isinstance(delete_filter, str): delete_filter = _parse_row_filter(delete_filter) - with self.update_snapshot(snapshot_properties=snapshot_properties).delete() as delete_snapshot: + with self.update_snapshot(snapshot_properties=snapshot_properties, branch=branch).delete() as delete_snapshot: delete_snapshot.delete_by_predicate(delete_filter, case_sensitive) # Check if there are any files that require an actual rewrite of a data file @@ -642,7 +658,10 @@ class Transaction: bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive) preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter) - files = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive).plan_files() + file_scan = self._scan(row_filter=delete_filter, case_sensitive=case_sensitive) + if branch is not None: + file_scan = file_scan.use_ref(branch) + files = file_scan.plan_files() commit_uuid = uuid.uuid4() counter = itertools.count(0) @@ -684,7 +703,9 @@ class Transaction: ) if len(replaced_files) > 0: - with self.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as overwrite_snapshot: + with self.update_snapshot( + snapshot_properties=snapshot_properties, branch=branch + ).overwrite() as overwrite_snapshot: overwrite_snapshot.commit_uuid = commit_uuid for original_data_file, replaced_data_files in replaced_files: overwrite_snapshot.delete_data_file(original_data_file) @@ -701,6 +722,7 @@ class Transaction: when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -711,6 +733,7 @@ class Transaction: when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table case_sensitive: Bool indicating if the match should be case-sensitive + branch: Branch Reference to run the upsert operation To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids @@ -773,12 +796,18 @@ class Transaction: matched_predicate = upsert_util.create_match_filter(df, join_cols) # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
- matched_iceberg_record_batches = DataScan( + + matched_iceberg_record_batches_scan = DataScan( table_metadata=self.table_metadata, io=self._table.io, row_filter=matched_predicate, case_sensitive=case_sensitive, - ).to_arrow_batch_reader() + ) + + if branch is not None: + matched_iceberg_record_batches_scan = matched_iceberg_record_batches_scan.use_ref(branch) + + matched_iceberg_record_batches = matched_iceberg_record_batches_scan.to_arrow_batch_reader() batches_to_overwrite = [] overwrite_predicates = [] @@ -817,12 +846,13 @@ class Transaction: self.overwrite( rows_to_update, overwrite_filter=Or(*overwrite_predicates) if len(overwrite_predicates) > 1 else overwrite_predicates[0], + branch=branch, ) if when_not_matched_insert_all: insert_row_cnt = len(rows_to_insert) if rows_to_insert: - self.append(rows_to_insert) + self.append(rows_to_insert, branch=branch) return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) @@ -1259,6 +1289,7 @@ class Table: when_matched_update_all: bool = True, when_not_matched_insert_all: bool = True, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> UpsertResult: """Shorthand API for performing an upsert to an iceberg table. @@ -1269,6 +1300,7 @@ class Table: when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table case_sensitive: Bool indicating if the match should be case-sensitive + branch: Branch Reference to run the upsert operation To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids @@ -1301,29 +1333,34 @@ class Table: when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all, case_sensitive=case_sensitive, + branch=branch, ) - def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> None: """ Shorthand API for appending a PyArrow table to the table. Args: df: The Arrow dataframe that will be appended to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the append operation """ with self.transaction() as tx: - tx.append(df=df, snapshot_properties=snapshot_properties) + tx.append(df=df, snapshot_properties=snapshot_properties, branch=branch) - def dynamic_partition_overwrite(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def dynamic_partition_overwrite( + self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None + ) -> None: """Shorthand for dynamic overwriting the table with a PyArrow table. Old partitions are auto detected and replaced with data files created for input arrow table. 
Args: df: The Arrow dataframe that will be used to overwrite the table snapshot_properties: Custom properties to be added to the snapshot summary + branch: Branch Reference to run the dynamic partition overwrite operation """ with self.transaction() as tx: - tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties) + tx.dynamic_partition_overwrite(df=df, snapshot_properties=snapshot_properties, branch=branch) def overwrite( self, @@ -1331,6 +1368,7 @@ class Table: overwrite_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> None: """ Shorthand for overwriting the table with a PyArrow table. @@ -1347,10 +1385,15 @@ class Table: or a boolean expression in case of a partial overwrite snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `overwrite_filter` is case-sensitive + branch: Branch Reference to run the overwrite operation """ with self.transaction() as tx: tx.overwrite( - df=df, overwrite_filter=overwrite_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties + df=df, + overwrite_filter=overwrite_filter, + case_sensitive=case_sensitive, + snapshot_properties=snapshot_properties, + branch=branch, ) def delete( @@ -1358,6 +1401,7 @@ class Table: delete_filter: Union[BooleanExpression, str] = ALWAYS_TRUE, snapshot_properties: Dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, + branch: Optional[str] = None, ) -> None: """ Shorthand for deleting rows from the table. @@ -1366,9 +1410,12 @@ class Table: delete_filter: The predicate that used to remove rows snapshot_properties: Custom properties to be added to the snapshot summary case_sensitive: A bool determine if the provided `delete_filter` is case-sensitive + branch: Branch Reference to run the delete operation """ with self.transaction() as tx: - tx.delete(delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties) + tx.delete( + delete_filter=delete_filter, case_sensitive=case_sensitive, snapshot_properties=snapshot_properties, branch=branch + ) def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index 4905c31b..6653f119 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -29,7 +29,7 @@ from pyiceberg.exceptions import CommitFailedException from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table.metadata import SUPPORTED_TABLE_FORMAT_VERSION, TableMetadata, TableMetadataUtil -from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( MetadataLogEntry, Snapshot, @@ -139,7 +139,7 @@ class AddSnapshotUpdate(IcebergBaseModel): class SetSnapshotRefUpdate(IcebergBaseModel): action: Literal["set-snapshot-ref"] = Field(default="set-snapshot-ref") ref_name: str = Field(alias="ref-name") - type: Literal["tag", "branch"] + type: Literal[SnapshotRefType.TAG, SnapshotRefType.BRANCH] snapshot_id: int = Field(alias="snapshot-id") max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)] max_snapshot_age_ms: Annotated[Optional[int], 
Field(alias="max-snapshot-age-ms", default=None)] @@ -702,6 +702,10 @@ class AssertRefSnapshotId(ValidatableTableRequirement): def validate(self, base_metadata: Optional[TableMetadata]) -> None: if base_metadata is None: raise CommitFailedException("Requirement failed: current table metadata is missing") + elif len(base_metadata.snapshots) == 0 and self.ref != MAIN_BRANCH: + raise CommitFailedException( + f"Requirement failed: Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH." + ) elif snapshot_ref := base_metadata.refs.get(self.ref): ref_type = snapshot_ref.snapshot_ref_type if self.snapshot_id is None: diff --git a/pyiceberg/table/update/snapshot.py b/pyiceberg/table/update/snapshot.py index 66a087f3..3ffb275d 100644 --- a/pyiceberg/table/update/snapshot.py +++ b/pyiceberg/table/update/snapshot.py @@ -56,7 +56,7 @@ from pyiceberg.manifest import ( from pyiceberg.partitioning import ( PartitionSpec, ) -from pyiceberg.table.refs import SnapshotRefType +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRefType from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -108,6 +108,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): _manifest_num_counter: itertools.count[int] _deleted_data_files: Set[DataFile] _compression: AvroCompressionCodec + _target_branch = MAIN_BRANCH def __init__( self, @@ -116,16 +117,13 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): io: FileIO, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, + branch: str = MAIN_BRANCH, ) -> None: super().__init__(transaction) self.commit_uuid = commit_uuid or uuid.uuid4() self._io = io self._operation = operation self._snapshot_id = self._transaction.table_metadata.new_snapshot_id() - # Since we only support the main branch for now - self._parent_snapshot_id = ( - snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.current_snapshot()) else None - ) self._added_data_files = [] self._deleted_data_files = set() self.snapshot_properties = snapshot_properties @@ -135,6 +133,20 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): self._compression = self._transaction.table_metadata.properties.get( # type: ignore TableProperties.WRITE_AVRO_COMPRESSION, TableProperties.WRITE_AVRO_COMPRESSION_DEFAULT ) + self._target_branch = self._validate_target_branch(branch=branch) + self._parent_snapshot_id = ( + snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None + ) + + def _validate_target_branch(self, branch: str) -> str: + # Default is already set to MAIN_BRANCH. So branch name can't be None. + if branch is None: + raise ValueError("Invalid branch name: null") + if branch in self._transaction.table_metadata.refs: + ref = self._transaction.table_metadata.refs[branch] + if ref.snapshot_ref_type != SnapshotRefType.BRANCH: + raise ValueError(f"{branch} is a tag, not a branch. 
Tags cannot be targets for producing snapshots") + return branch def append_data_file(self, data_file: DataFile) -> _SnapshotProducer[U]: self._added_data_files.append(data_file) @@ -284,10 +296,20 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]): ( AddSnapshotUpdate(snapshot=snapshot), SetSnapshotRefUpdate( - snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch" + snapshot_id=self._snapshot_id, + parent_snapshot_id=self._parent_snapshot_id, + ref_name=self._target_branch, + type=SnapshotRefType.BRANCH, + ), + ), + ( + AssertRefSnapshotId( + snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id + if self._target_branch in self._transaction.table_metadata.refs + else None, + ref=self._target_branch, ), ), - (AssertRefSnapshotId(snapshot_id=self._transaction.table_metadata.current_snapshot_id, ref="main"),), ) @property @@ -335,10 +357,11 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]): operation: Operation, transaction: Transaction, io: FileIO, + branch: str, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, ): - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) self._predicate = AlwaysFalse() self._case_sensitive = True @@ -398,47 +421,53 @@ class _DeleteFiles(_SnapshotProducer["_DeleteFiles"]): total_deleted_entries = [] partial_rewrites_needed = False self._deleted_data_files = set() - if snapshot := self._transaction.table_metadata.current_snapshot(): - for manifest_file in snapshot.manifests(io=self._io): - if manifest_file.content == ManifestContent.DATA: - if not manifest_evaluators[manifest_file.partition_spec_id](manifest_file): - # If the manifest isn't relevant, we can just keep it in the manifest-list - existing_manifests.append(manifest_file) - else: - # It is relevant, let's check out the content - deleted_entries = [] - existing_entries = [] - for entry in manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True): - if strict_metrics_evaluator(entry.data_file) == ROWS_MUST_MATCH: - # Based on the metadata, it can be dropped right away - deleted_entries.append(_copy_with_new_status(entry, ManifestEntryStatus.DELETED)) - self._deleted_data_files.add(entry.data_file) - else: - # Based on the metadata, we cannot determine if it can be deleted - existing_entries.append(_copy_with_new_status(entry, ManifestEntryStatus.EXISTING)) - if inclusive_metrics_evaluator(entry.data_file) != ROWS_MIGHT_NOT_MATCH: - partial_rewrites_needed = True - - if len(deleted_entries) > 0: - total_deleted_entries += deleted_entries - - # Rewrite the manifest - if len(existing_entries) > 0: - with write_manifest( - format_version=self._transaction.table_metadata.format_version, - spec=self._transaction.table_metadata.specs()[manifest_file.partition_spec_id], - schema=self._transaction.table_metadata.schema(), - output_file=self.new_manifest_output(), - snapshot_id=self._snapshot_id, - avro_compression=self._compression, - ) as writer: - for existing_entry in existing_entries: - writer.add_entry(existing_entry) - existing_manifests.append(writer.to_manifest_file()) - else: + + # Determine the snapshot to read manifests from for deletion + # Should be the current tip of the _target_branch + parent_snapshot_id_for_delete_source = self._parent_snapshot_id + if parent_snapshot_id_for_delete_source is not None: + snapshot = 
self._transaction.table_metadata.snapshot_by_id(parent_snapshot_id_for_delete_source) + if snapshot: # Ensure snapshot is found + for manifest_file in snapshot.manifests(io=self._io): + if manifest_file.content == ManifestContent.DATA: + if not manifest_evaluators[manifest_file.partition_spec_id](manifest_file): + # If the manifest isn't relevant, we can just keep it in the manifest-list existing_manifests.append(manifest_file) - else: - existing_manifests.append(manifest_file) + else: + # It is relevant, let's check out the content + deleted_entries = [] + existing_entries = [] + for entry in manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True): + if strict_metrics_evaluator(entry.data_file) == ROWS_MUST_MATCH: + # Based on the metadata, it can be dropped right away + deleted_entries.append(_copy_with_new_status(entry, ManifestEntryStatus.DELETED)) + self._deleted_data_files.add(entry.data_file) + else: + # Based on the metadata, we cannot determine if it can be deleted + existing_entries.append(_copy_with_new_status(entry, ManifestEntryStatus.EXISTING)) + if inclusive_metrics_evaluator(entry.data_file) != ROWS_MIGHT_NOT_MATCH: + partial_rewrites_needed = True + + if len(deleted_entries) > 0: + total_deleted_entries += deleted_entries + + # Rewrite the manifest + if len(existing_entries) > 0: + with write_manifest( + format_version=self._transaction.table_metadata.format_version, + spec=self._transaction.table_metadata.specs()[manifest_file.partition_spec_id], + schema=self._transaction.table_metadata.schema(), + output_file=self.new_manifest_output(), + snapshot_id=self._snapshot_id, + avro_compression=self._compression, + ) as writer: + for existing_entry in existing_entries: + writer.add_entry(existing_entry) + existing_manifests.append(writer.to_manifest_file()) + else: + existing_manifests.append(manifest_file) + else: + existing_manifests.append(manifest_file) return existing_manifests, total_deleted_entries, partial_rewrites_needed @@ -498,12 +527,13 @@ class _MergeAppendFiles(_FastAppendFiles): operation: Operation, transaction: Transaction, io: FileIO, + branch: str, commit_uuid: Optional[uuid.UUID] = None, snapshot_properties: Dict[str, str] = EMPTY_DICT, ) -> None: from pyiceberg.table import TableProperties - super().__init__(operation, transaction, io, commit_uuid, snapshot_properties) + super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch) self._target_size_bytes = property_as_int( self._transaction.table_metadata.properties, TableProperties.MANIFEST_TARGET_SIZE_BYTES, @@ -549,7 +579,7 @@ class _OverwriteFiles(_SnapshotProducer["_OverwriteFiles"]): """Determine if there are any existing manifest files.""" existing_files = [] - if snapshot := self._transaction.table_metadata.current_snapshot(): + if snapshot := self._transaction.table_metadata.snapshot_by_name(name=self._target_branch): for manifest_file in snapshot.manifests(io=self._io): entries = manifest_file.fetch_manifest_entry(io=self._io, discard_deleted=True) found_deleted_data_files = [entry.data_file for entry in entries if entry.data_file in self._deleted_data_files] @@ -567,19 +597,17 @@ class _OverwriteFiles(_SnapshotProducer["_OverwriteFiles"]): snapshot_id=self._snapshot_id, avro_compression=self._compression, ) as writer: - [ - writer.add_entry( - ManifestEntry.from_args( - status=ManifestEntryStatus.EXISTING, - snapshot_id=entry.snapshot_id, - sequence_number=entry.sequence_number, - file_sequence_number=entry.file_sequence_number, - 
data_file=entry.data_file, + for entry in entries: + if entry.data_file not in found_deleted_data_files: + writer.add_entry( + ManifestEntry.from_args( + status=ManifestEntryStatus.EXISTING, + snapshot_id=entry.snapshot_id, + sequence_number=entry.sequence_number, + file_sequence_number=entry.file_sequence_number, + data_file=entry.data_file, + ) ) - ) - for entry in entries - if entry.data_file not in found_deleted_data_files - ] existing_files.append(writer.to_manifest_file()) return existing_files @@ -620,31 +648,48 @@ class _OverwriteFiles(_SnapshotProducer["_OverwriteFiles"]): class UpdateSnapshot: _transaction: Transaction _io: FileIO + _branch: str _snapshot_properties: Dict[str, str] - def __init__(self, transaction: Transaction, io: FileIO, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: + def __init__( + self, + transaction: Transaction, + io: FileIO, + branch: str, + snapshot_properties: Dict[str, str] = EMPTY_DICT, + ) -> None: self._transaction = transaction self._io = io self._snapshot_properties = snapshot_properties + self._branch = branch def fast_append(self) -> _FastAppendFiles: return _FastAppendFiles( - operation=Operation.APPEND, transaction=self._transaction, io=self._io, snapshot_properties=self._snapshot_properties + operation=Operation.APPEND, + transaction=self._transaction, + io=self._io, + branch=self._branch, + snapshot_properties=self._snapshot_properties, ) def merge_append(self) -> _MergeAppendFiles: return _MergeAppendFiles( - operation=Operation.APPEND, transaction=self._transaction, io=self._io, snapshot_properties=self._snapshot_properties + operation=Operation.APPEND, + transaction=self._transaction, + io=self._io, + branch=self._branch, + snapshot_properties=self._snapshot_properties, ) def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles: return _OverwriteFiles( commit_uuid=commit_uuid, operation=Operation.OVERWRITE - if self._transaction.table_metadata.current_snapshot() is not None + if self._transaction.table_metadata.snapshot_by_name(name=self._branch) is not None else Operation.APPEND, transaction=self._transaction, io=self._io, + branch=self._branch, snapshot_properties=self._snapshot_properties, ) @@ -653,6 +698,7 @@ class UpdateSnapshot: operation=Operation.DELETE, transaction=self._transaction, io=self._io, + branch=self._branch, snapshot_properties=self._snapshot_properties, ) diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py index 805599bf..751cbd9b 100644 --- a/pyiceberg/utils/concurrent.py +++ b/pyiceberg/utils/concurrent.py @@ -25,6 +25,11 @@ from pyiceberg.utils.config import Config class ExecutorFactory: _instance: Optional[Executor] = None + @staticmethod + def max_workers() -> Optional[int]: + """Return the max number of workers configured.""" + return Config().get_int("max-workers") + @staticmethod def get_or_create() -> Executor: """Return the same executor in each call.""" @@ -33,8 +38,3 @@ class ExecutorFactory: ExecutorFactory._instance = ThreadPoolExecutor(max_workers=max_workers) return ExecutorFactory._instance - - @staticmethod - def max_workers() -> Optional[int]: - """Return the max number of workers configured.""" - return Config().get_int("max-workers") diff --git a/tests/integration/test_deletes.py b/tests/integration/test_deletes.py index 527f6596..abf8502a 100644 --- a/tests/integration/test_deletes.py +++ b/tests/integration/test_deletes.py @@ -894,3 +894,32 @@ def test_overwrite_with_filter_case_insensitive(test_table: Table) -> None: 
test_table.overwrite(df=new_table, overwrite_filter=f"Idx == {record_to_overwrite['idx']}", case_sensitive=False) assert record_to_overwrite not in test_table.scan().to_arrow().to_pylist() assert new_record_to_insert in test_table.scan().to_arrow().to_pylist() + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +@pytest.mark.filterwarnings("ignore:Delete operation did not match any records") +def test_delete_on_empty_table(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None: + identifier = f"default.test_delete_on_empty_table_{format_version}" + + run_spark_commands( + spark, + [ + f"DROP TABLE IF EXISTS {identifier}", + f""" + CREATE TABLE {identifier} ( + volume int + ) + USING iceberg + TBLPROPERTIES('format-version' = {format_version}) + """, + ], + ) + + tbl = session_catalog.load_table(identifier) + + # Perform a delete operation on the empty table + tbl.delete(AlwaysTrue()) + + # Assert that no new snapshot was created because no rows were deleted + assert len(tbl.snapshots()) == 0 diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 033a9f7c..b66601f6 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -41,12 +41,13 @@ from pytest_mock.plugin import MockerFixture from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.catalog.hive import HiveCatalog from pyiceberg.catalog.sql import SqlCatalog -from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.exceptions import CommitFailedException, NoSuchTableError from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not from pyiceberg.io.pyarrow import _dataframe_to_data_files from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import TableProperties +from pyiceberg.table.refs import MAIN_BRANCH from pyiceberg.table.sorting import SortDirection, SortField, SortOrder from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform from pyiceberg.types import ( @@ -1856,3 +1857,160 @@ def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null with tbl.io.new_input(current_snapshot.manifest_list).open() as f: reader = fastavro.reader(f) assert reader.codec == "null" + + +@pytest.mark.integration +def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.test_non_existing_branch" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, []) + with pytest.raises( + CommitFailedException, match=f"Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH." 
+ ): + tbl.append(arrow_table_with_null, branch="non_existing_branch") + + +@pytest.mark.integration +def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.test_existing_branch_append" + branch = "existing_branch" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + + assert tbl.metadata.current_snapshot_id is not None + + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() + tbl.append(arrow_table_with_null, branch=branch) + + assert len(tbl.scan().use_ref(branch).to_arrow()) == 6 + assert len(tbl.scan().to_arrow()) == 3 + branch_snapshot = tbl.metadata.snapshot_by_name(branch) + assert branch_snapshot is not None + main_snapshot = tbl.metadata.snapshot_by_name("main") + assert main_snapshot is not None + assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id + + +@pytest.mark.integration +def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.test_existing_branch_delete" + branch = "existing_branch" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + + assert tbl.metadata.current_snapshot_id is not None + + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() + tbl.delete(delete_filter="int = 9", branch=branch) + + assert len(tbl.scan().use_ref(branch).to_arrow()) == 2 + assert len(tbl.scan().to_arrow()) == 3 + branch_snapshot = tbl.metadata.snapshot_by_name(branch) + assert branch_snapshot is not None + main_snapshot = tbl.metadata.snapshot_by_name("main") + assert main_snapshot is not None + assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id + + +@pytest.mark.integration +def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.test_existing_branch_overwrite" + branch = "existing_branch" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + + assert tbl.metadata.current_snapshot_id is not None + + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() + tbl.overwrite(arrow_table_with_null, branch=branch) + + assert len(tbl.scan().use_ref(branch).to_arrow()) == 3 + assert len(tbl.scan().to_arrow()) == 3 + branch_snapshot = tbl.metadata.snapshot_by_name(branch) + assert branch_snapshot is not None and branch_snapshot.parent_snapshot_id is not None + delete_snapshot = tbl.metadata.snapshot_by_id(branch_snapshot.parent_snapshot_id) + assert delete_snapshot is not None + main_snapshot = tbl.metadata.snapshot_by_name("main") + assert main_snapshot is not None + assert ( + delete_snapshot.parent_snapshot_id == main_snapshot.snapshot_id + ) # Currently overwrite is a delete followed by an append operation + + +@pytest.mark.integration +def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.test_intertwined_branch_operations" + branch1 = "existing_branch_1" + branch2 = "existing_branch_2" + + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + + assert tbl.metadata.current_snapshot_id is not None + + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, 
branch_name=branch1).commit() + + tbl.delete("int = 9", branch=branch1) + + tbl.append(arrow_table_with_null) + + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch2).commit() + + tbl.overwrite(arrow_table_with_null, branch=branch2) + + assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2 + assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3 + assert len(tbl.scan().to_arrow()) == 6 + + +@pytest.mark.integration +def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None: + # Initialize table with branch + identifier = "default.test_branch_spark_write_py_read" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + branch = "existing_spark_branch" + + # Create branch in Spark + spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}") + + # Spark Write + spark.sql( + f""" + DELETE FROM {identifier}.branch_{branch} + WHERE int = 9 + """ + ) + + # Refresh table to get new refs + tbl.refresh() + + # Python Read + assert len(tbl.scan().to_arrow()) == 3 + assert len(tbl.scan().use_ref(branch).to_arrow()) == 2 + + +@pytest.mark.integration +def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None: + # Initialize table with branch + identifier = "default.test_branch_py_write_spark_read" + tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) + branch = "existing_py_branch" + + assert tbl.metadata.current_snapshot_id is not None + + # Create branch + tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() + + # Python Write + tbl.delete("int = 9", branch=branch) + + # Spark Read + main_df = spark.sql( + f""" + SELECT * + FROM {identifier} + """ + ) + branch_df = spark.sql( + f""" + SELECT * + FROM {identifier}.branch_{branch} + """ + ) + assert main_df.count() == 3 + assert branch_df.count() == 2 diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 6165dade..89524a86 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -50,7 +50,7 @@ from pyiceberg.table import ( _match_deletes_to_data_file, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( MetadataLogEntry, Operation, @@ -1000,28 +1000,42 @@ def test_assert_table_uuid(table_v2: Table) -> None: def test_assert_ref_snapshot_id(table_v2: Table) -> None: base_metadata = table_v2.metadata - AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) + AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): - AssertRefSnapshotId(ref="main", snapshot_id=1).validate(None) + AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=1).validate(None) with pytest.raises( CommitFailedException, - match="Requirement failed: branch main was created concurrently", + match=f"Requirement failed: branch {MAIN_BRANCH} was created concurrently", ): - AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) + AssertRefSnapshotId(ref=MAIN_BRANCH, 
snapshot_id=None).validate(base_metadata) with pytest.raises( CommitFailedException, - match="Requirement failed: branch main has changed: expected id 1, found 3055729675574597004", + match=f"Requirement failed: branch {MAIN_BRANCH} has changed: expected id 1, found 3055729675574597004", ): - AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) + AssertRefSnapshotId(ref=MAIN_BRANCH, snapshot_id=1).validate(base_metadata) + + non_existing_ref = "not_exist_branch_or_tag" + assert table_v2.refs().get("not_exist_branch_or_tag") is None + + with pytest.raises( + CommitFailedException, + match=f"Requirement failed: branch or tag {non_existing_ref} is missing, expected 1", + ): + AssertRefSnapshotId(ref=non_existing_ref, snapshot_id=1).validate(base_metadata) + + # existing Tag in metadata: test + ref_tag = table_v2.refs().get("test") + assert ref_tag is not None + assert ref_tag.snapshot_ref_type == SnapshotRefType.TAG, "TAG test should be present in table to be tested" with pytest.raises( CommitFailedException, - match="Requirement failed: branch or tag not_exist is missing, expected 1", + match="Requirement failed: tag test has changed: expected id 3055729675574597004, found 3051729675574597004", ): - AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata) + AssertRefSnapshotId(ref="test", snapshot_id=3055729675574597004).validate(base_metadata) def test_assert_last_assigned_field_id(table_v2: Table) -> None:
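
Note: below is a minimal usage sketch of the branch-aware write API added by this commit, pieced together from the signatures and integration tests in the diff above. The catalog name, table identifier, column names, and sample rows are illustrative placeholders and are not part of the commit.

import pyarrow as pa
from pyiceberg.catalog import load_catalog

# Illustrative names: the "default" catalog and table identifier are placeholders.
catalog = load_catalog("default")
tbl = catalog.load_table("default.branch_write_example")

df = pa.Table.from_pylist([{"idx": 1, "value": "a"}, {"idx": 2, "value": "b"}])

# Create the target branch first, as the integration tests in this commit do;
# a table with no snapshots can only be written to the main branch.
assert tbl.metadata.current_snapshot_id is not None
tbl.manage_snapshots().create_branch(
    snapshot_id=tbl.metadata.current_snapshot_id, branch_name="audit"
).commit()

# The write shorthands now accept a `branch` keyword (default: main).
tbl.append(df, branch="audit")
tbl.delete(delete_filter="idx = 1", branch="audit")
tbl.overwrite(df, branch="audit")

# Reads against the branch use the existing use_ref() scan API;
# plain scans continue to read the main branch.
branch_rows = tbl.scan().use_ref("audit").to_arrow()
main_rows = tbl.scan().to_arrow()

As the overwrite integration test above notes, an overwrite against a branch currently commits as a delete snapshot followed by an append snapshot on that branch.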