This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new a8923099 Add CreateTableTransaction API and implement it in Glue and
Rest (#498)
a8923099 is described below
commit a892309936effa7ec575195ad3be70193e82d704
Author: Honah J <[email protected]>
AuthorDate: Thu Apr 4 01:02:33 2024 -0400
Add CreateTableTransaction API and implement it in Glue and Rest (#498)
---
mkdocs/docs/api.md | 19 +++
pyiceberg/catalog/__init__.py | 298 ++++++++++++++++++++++-----------
pyiceberg/catalog/dynamodb.py | 4 +-
pyiceberg/catalog/glue.py | 136 +++++++++------
pyiceberg/catalog/hive.py | 4 +-
pyiceberg/catalog/noop.py | 18 ++
pyiceberg/catalog/rest.py | 76 ++++++++-
pyiceberg/catalog/sql.py | 4 +-
pyiceberg/table/__init__.py | 160 +++++++++++++++---
pyiceberg/table/metadata.py | 2 +-
tests/catalog/integration_test_glue.py | 92 +++++++++-
tests/catalog/test_base.py | 3 +-
tests/catalog/test_glue.py | 48 ++++++
tests/integration/test_writes.py | 54 ++++++
14 files changed, 723 insertions(+), 195 deletions(-)
diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 828dd186..c8620af7 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -165,6 +165,25 @@ catalog.create_table(
)
```
+To create a table and apply subsequent changes to it atomically, within a single transaction:
+
+```python
+with catalog.create_table_transaction(
+ identifier="docs_example.bids",
+ schema=schema,
+ location="s3://pyiceberg",
+ partition_spec=partition_spec,
+ sort_order=sort_order,
+) as txn:
+ with txn.update_schema() as update_schema:
+ update_schema.add_column(path="new_column", field_type=StringType())
+
+ with txn.update_spec() as update_spec:
+ update_spec.add_identity("symbol")
+
+ txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c")
+```
+
## Load a table
### Catalog table
diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index f2b46fcd..f104aa94 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -45,9 +45,11 @@ from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
+ CreateTableTransaction,
+ StagedTable,
Table,
)
-from pyiceberg.table.metadata import TableMetadata
+from pyiceberg.table.metadata import TableMetadata, TableMetadataV1,
new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import (
EMPTY_DICT,
@@ -285,9 +287,6 @@ class Catalog(ABC):
self.name = name
self.properties = properties
- def _load_file_io(self, properties: Properties = EMPTY_DICT, location:
Optional[str] = None) -> FileIO:
- return load_file_io({**self.properties, **properties}, location)
-
@abstractmethod
def create_table(
self,
@@ -315,6 +314,30 @@ class Catalog(ABC):
TableAlreadyExistsError: If a table with the name already exists.
"""
+ @abstractmethod
+ def create_table_transaction(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> CreateTableTransaction:
+ """Create a CreateTableTransaction.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+ schema (Schema): Table's schema.
+ location (str | None): Location for the table. Optional Argument.
+ partition_spec (PartitionSpec): PartitionSpec for the table.
+ sort_order (SortOrder): SortOrder for the table.
+ properties (Properties): Table properties that can be a string
based dictionary.
+
+ Returns:
+ CreateTableTransaction: createTableTransaction instance.
+ """
+
def create_table_if_not_exists(
self,
identifier: Union[str, Identifier],
@@ -360,6 +383,17 @@ class Catalog(ABC):
NoSuchTableError: If a table with the name does not exist.
"""
+ @abstractmethod
+ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
+ """Check if a table exists.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+
+ Returns:
+ bool: True if the table exists, False otherwise.
+ """
+
@abstractmethod
def register_table(self, identifier: Union[str, Identifier],
metadata_location: str) -> Table:
"""Register a new table using existing metadata.
@@ -386,6 +420,19 @@ class Catalog(ABC):
NoSuchTableError: If a table with the name does not exist.
"""
+ @abstractmethod
+ def purge_table(self, identifier: Union[str, Identifier]) -> None:
+ """Drop a table and purge all data and metadata files.
+
+ Note: This method only logs warning rather than raise exception when
encountering file deletion failure.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+
+ Raises:
+ NoSuchTableError: If a table with the name does not exist, or the
identifier is invalid.
+ """
+
@abstractmethod
def rename_table(self, from_identifier: Union[str, Identifier],
to_identifier: Union[str, Identifier]) -> Table:
"""Rename a fully classified table name.
@@ -501,6 +548,20 @@ class Catalog(ABC):
ValueError: If removals and updates have overlapping keys.
"""
+ def identifier_to_tuple_without_catalog(self, identifier: Union[str,
Identifier]) -> Identifier:
+ """Convert an identifier to a tuple and drop this catalog's name from
the first element.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+
+ Returns:
+ Identifier: a tuple of strings with this catalog's name removed
+ """
+ identifier_tuple = Catalog.identifier_to_tuple(identifier)
+ if len(identifier_tuple) >= 3 and identifier_tuple[0] == self.name:
+ identifier_tuple = identifier_tuple[1:]
+ return identifier_tuple
+
@staticmethod
def identifier_to_tuple(identifier: Union[str, Identifier]) -> Identifier:
"""Parse an identifier to a tuple.
@@ -539,46 +600,6 @@ class Catalog(ABC):
"""
return Catalog.identifier_to_tuple(identifier)[:-1]
- @staticmethod
- def _check_for_overlap(removals: Optional[Set[str]], updates: Properties)
-> None:
- if updates and removals:
- overlap = set(removals) & set(updates.keys())
- if overlap:
- raise ValueError(f"Updates and deletes have an overlap:
{overlap}")
-
- @staticmethod
- def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) ->
Schema:
- if isinstance(schema, Schema):
- return schema
- try:
- import pyarrow as pa
-
- from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs,
visit_pyarrow
-
- if isinstance(schema, pa.Schema):
- schema: Schema = visit_pyarrow(schema,
_ConvertToIcebergWithoutIDs()) # type: ignore
- return schema
- except ModuleNotFoundError:
- pass
- raise ValueError(f"{type(schema)=}, but it must be
pyiceberg.schema.Schema or pyarrow.Schema")
-
- def _resolve_table_location(self, location: Optional[str], database_name:
str, table_name: str) -> str:
- if not location:
- return self._get_default_warehouse_location(database_name,
table_name)
- return location
-
- def _get_default_warehouse_location(self, database_name: str, table_name:
str) -> str:
- database_properties = self.load_namespace_properties(database_name)
- if database_location := database_properties.get(LOCATION):
- database_location = database_location.rstrip("/")
- return f"{database_location}/{table_name}"
-
- if warehouse_path := self.properties.get(WAREHOUSE_LOCATION):
- warehouse_path = warehouse_path.rstrip("/")
- return f"{warehouse_path}/{database_name}.db/{table_name}"
-
- raise ValueError("No default path is set, please specify a location
when creating a table")
-
@staticmethod
def identifier_to_database(
identifier: Union[str, Identifier], err: Union[Type[ValueError],
Type[NoSuchNamespaceError]] = ValueError
@@ -600,31 +621,52 @@ class Catalog(ABC):
return tuple_identifier[0], tuple_identifier[1]
- def identifier_to_tuple_without_catalog(self, identifier: Union[str,
Identifier]) -> Identifier:
- """Convert an identifier to a tuple and drop this catalog's name from
the first element.
+ def _load_file_io(self, properties: Properties = EMPTY_DICT, location:
Optional[str] = None) -> FileIO:
+ return load_file_io({**self.properties, **properties}, location)
- Args:
- identifier (str | Identifier): Table identifier.
+ @staticmethod
+ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) ->
Schema:
+ if isinstance(schema, Schema):
+ return schema
+ try:
+ import pyarrow as pa
- Returns:
- Identifier: a tuple of strings with this catalog's name removed
- """
- identifier_tuple = Catalog.identifier_to_tuple(identifier)
- if len(identifier_tuple) >= 3 and identifier_tuple[0] == self.name:
- identifier_tuple = identifier_tuple[1:]
- return identifier_tuple
+ from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs,
visit_pyarrow
- def purge_table(self, identifier: Union[str, Identifier]) -> None:
- """Drop a table and purge all data and metadata files.
+ if isinstance(schema, pa.Schema):
+ schema: Schema = visit_pyarrow(schema,
_ConvertToIcebergWithoutIDs()) # type: ignore
+ return schema
+ except ModuleNotFoundError:
+ pass
+ raise ValueError(f"{type(schema)=}, but it must be
pyiceberg.schema.Schema or pyarrow.Schema")
- Note: This method only logs warning rather than raise exception when
encountering file deletion failure.
+ def __repr__(self) -> str:
+ """Return the string representation of the Catalog class."""
+ return f"{self.name} ({self.__class__})"
- Args:
- identifier (str | Identifier): Table identifier.
- Raises:
- NoSuchTableError: If a table with the name does not exist, or the
identifier is invalid.
- """
+class MetastoreCatalog(Catalog, ABC):
+ def create_table_transaction(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> CreateTableTransaction:
+ return CreateTableTransaction(
+ self._create_staged_table(identifier, schema, location,
partition_spec, sort_order, properties)
+ )
+
+ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
+ try:
+ self.load_table(identifier)
+ return True
+ except NoSuchTableError:
+ return False
+
+ def purge_table(self, identifier: Union[str, Identifier]) -> None:
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
table = self.load_table(identifier_tuple)
self.drop_table(identifier_tuple)
@@ -646,12 +688,88 @@ class Catalog(ABC):
delete_files(io, prev_metadata_files, PREVIOUS_METADATA)
delete_files(io, {table.metadata_location}, METADATA)
- def table_exists(self, identifier: Union[str, Identifier]) -> bool:
- try:
- self.load_table(identifier)
- return True
- except NoSuchTableError:
- return False
+ def _create_staged_table(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> StagedTable:
+ """Create a table and return the table instance without committing the
changes.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+ schema (Schema): Table's schema.
+ location (str | None): Location for the table. Optional Argument.
+ partition_spec (PartitionSpec): PartitionSpec for the table.
+ sort_order (SortOrder): SortOrder for the table.
+ properties (Properties): Table properties that can be a string
based dictionary.
+
+ Returns:
+ StagedTable: the created staged table instance.
+ """
+ schema: Schema = self._convert_schema_if_needed(schema) # type: ignore
+
+ database_name, table_name =
self.identifier_to_database_and_table(identifier)
+
+ location = self._resolve_table_location(location, database_name,
table_name)
+ metadata_location = self._get_metadata_location(location=location)
+ metadata = new_table_metadata(
+ location=location, schema=schema, partition_spec=partition_spec,
sort_order=sort_order, properties=properties
+ )
+ io = load_file_io(properties=self.properties,
location=metadata_location)
+ return StagedTable(
+ identifier=(self.name, database_name, table_name),
+ metadata=metadata,
+ metadata_location=metadata_location,
+ io=io,
+ catalog=self,
+ )
+
+ def _get_updated_props_and_update_summary(
+ self, current_properties: Properties, removals: Optional[Set[str]],
updates: Properties
+ ) -> Tuple[PropertiesUpdateSummary, Properties]:
+ self._check_for_overlap(updates=updates, removals=removals)
+ updated_properties = dict(current_properties)
+
+ removed: Set[str] = set()
+ updated: Set[str] = set()
+
+ if removals:
+ for key in removals:
+ if key in updated_properties:
+ updated_properties.pop(key)
+ removed.add(key)
+ if updates:
+ for key, value in updates.items():
+ updated_properties[key] = value
+ updated.add(key)
+
+ expected_to_change = (removals or set()).difference(removed)
+ properties_update_summary = PropertiesUpdateSummary(
+ removed=list(removed or []), updated=list(updated or []),
missing=list(expected_to_change)
+ )
+
+ return properties_update_summary, updated_properties
+
+ def _resolve_table_location(self, location: Optional[str], database_name:
str, table_name: str) -> str:
+ if not location:
+ return self._get_default_warehouse_location(database_name,
table_name)
+ return location
+
+ def _get_default_warehouse_location(self, database_name: str, table_name:
str) -> str:
+ database_properties = self.load_namespace_properties(database_name)
+ if database_location := database_properties.get(LOCATION):
+ database_location = database_location.rstrip("/")
+ return f"{database_location}/{table_name}"
+
+ if warehouse_path := self.properties.get(WAREHOUSE_LOCATION):
+ warehouse_path = warehouse_path.rstrip("/")
+ return f"{warehouse_path}/{database_name}.db/{table_name}"
+
+ raise ValueError("No default path is set, please specify a location
when creating a table")
@staticmethod
def _write_metadata(metadata: TableMetadata, io: FileIO, metadata_path:
str) -> None:
@@ -691,32 +809,22 @@ class Catalog(ABC):
else:
return -1
- def _get_updated_props_and_update_summary(
- self, current_properties: Properties, removals: Optional[Set[str]],
updates: Properties
- ) -> Tuple[PropertiesUpdateSummary, Properties]:
- self._check_for_overlap(updates=updates, removals=removals)
- updated_properties = dict(current_properties)
-
- removed: Set[str] = set()
- updated: Set[str] = set()
-
- if removals:
- for key in removals:
- if key in updated_properties:
- updated_properties.pop(key)
- removed.add(key)
- if updates:
- for key, value in updates.items():
- updated_properties[key] = value
- updated.add(key)
+ @staticmethod
+ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties)
-> None:
+ if updates and removals:
+ overlap = set(removals) & set(updates.keys())
+ if overlap:
+ raise ValueError(f"Updates and deletes have an overlap:
{overlap}")
- expected_to_change = (removals or set()).difference(removed)
- properties_update_summary = PropertiesUpdateSummary(
- removed=list(removed or []), updated=list(updated or []),
missing=list(expected_to_change)
- )
+ @staticmethod
+ def _empty_table_metadata() -> TableMetadata:
+ """Return an empty TableMetadata instance.
- return properties_update_summary, updated_properties
+ It is used to build a TableMetadata from a sequence of initial
TableUpdates.
+ It is a V1 TableMetadata because there will be a
UpgradeFormatVersionUpdate in
+ initial changes to bump the metadata to the target version.
- def __repr__(self) -> str:
- """Return the string representation of the Catalog class."""
- return f"{self.name} ({self.__class__})"
+ Returns:
+ TableMetadata: An empty TableMetadata instance.
+ """
+ return TableMetadataV1(location="", last_column_id=-1, schema=Schema())
diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py
index 266dd635..bc5cbede 100644
--- a/pyiceberg/catalog/dynamodb.py
+++ b/pyiceberg/catalog/dynamodb.py
@@ -33,7 +33,7 @@ from pyiceberg.catalog import (
METADATA_LOCATION,
PREVIOUS_METADATA_LOCATION,
TABLE_TYPE,
- Catalog,
+ MetastoreCatalog,
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
@@ -79,7 +79,7 @@ ACTIVE = "ACTIVE"
ITEM = "Item"
-class DynamoDbCatalog(Catalog):
+class DynamoDbCatalog(MetastoreCatalog):
def __init__(self, name: str, **properties: str):
super().__init__(name, **properties)
session = boto3.Session(
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index adec150d..e7532677 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -45,7 +45,7 @@ from pyiceberg.catalog import (
METADATA_LOCATION,
PREVIOUS_METADATA_LOCATION,
TABLE_TYPE,
- Catalog,
+ MetastoreCatalog,
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
@@ -62,8 +62,13 @@ from pyiceberg.io import load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table,
update_table_metadata
-from pyiceberg.table.metadata import TableMetadata, new_table_metadata
+from pyiceberg.table import (
+ CommitTableRequest,
+ CommitTableResponse,
+ Table,
+ update_table_metadata,
+)
+from pyiceberg.table.metadata import TableMetadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
from pyiceberg.types import (
@@ -273,7 +278,7 @@ def _register_glue_catalog_id_with_glue_client(glue:
GlueClient, glue_catalog_id
event_system.register("provide-client-params.glue", add_glue_catalog_id)
-class GlueCatalog(Catalog):
+class GlueCatalog(MetastoreCatalog):
def __init__(self, name: str, **properties: Any):
super().__init__(name, **properties)
@@ -384,20 +389,18 @@ class GlueCatalog(Catalog):
ValueError: If the identifier is invalid, or no path is given to
store metadata.
"""
- schema: Schema = self._convert_schema_if_needed(schema) # type: ignore
-
- database_name, table_name =
self.identifier_to_database_and_table(identifier)
-
- location = self._resolve_table_location(location, database_name,
table_name)
- metadata_location = self._get_metadata_location(location=location)
- metadata = new_table_metadata(
- location=location, schema=schema, partition_spec=partition_spec,
sort_order=sort_order, properties=properties
+ staged_table = self._create_staged_table(
+ identifier=identifier,
+ schema=schema,
+ location=location,
+ partition_spec=partition_spec,
+ sort_order=sort_order,
+ properties=properties,
)
- io = load_file_io(properties=self.properties,
location=metadata_location)
- self._write_metadata(metadata, io, metadata_location)
-
- table_input = _construct_table_input(table_name, metadata_location,
properties, metadata)
database_name, table_name =
self.identifier_to_database_and_table(identifier)
+
+ self._write_metadata(staged_table.metadata, staged_table.io,
staged_table.metadata_location)
+ table_input = _construct_table_input(table_name,
staged_table.metadata_location, properties, staged_table.metadata)
self._create_glue_table(database_name=database_name,
table_name=table_name, table_input=table_input)
return self.load_table(identifier=identifier)
@@ -435,46 +438,71 @@ class GlueCatalog(Catalog):
)
database_name, table_name =
self.identifier_to_database_and_table(identifier_tuple)
- current_glue_table = self._get_glue_table(database_name=database_name,
table_name=table_name)
- glue_table_version_id = current_glue_table.get("VersionId")
- if not glue_table_version_id:
- raise CommitFailedException(f"Cannot commit
{database_name}.{table_name} because Glue table version id is missing")
- current_table =
self._convert_glue_to_iceberg(glue_table=current_glue_table)
- base_metadata = current_table.metadata
-
- # Validate the update requirements
- for requirement in table_request.requirements:
- requirement.validate(base_metadata)
-
- updated_metadata = update_table_metadata(base_metadata,
table_request.updates)
- if updated_metadata == base_metadata:
- # no changes, do nothing
- return CommitTableResponse(metadata=base_metadata,
metadata_location=current_table.metadata_location)
-
- # write new metadata
- new_metadata_version =
self._parse_metadata_version(current_table.metadata_location) + 1
- new_metadata_location =
self._get_metadata_location(current_table.metadata.location,
new_metadata_version)
- self._write_metadata(updated_metadata, current_table.io,
new_metadata_location)
-
- update_table_input = _construct_table_input(
- table_name=table_name,
- metadata_location=new_metadata_location,
- properties=current_table.properties,
- metadata=updated_metadata,
- glue_table=current_glue_table,
- prev_metadata_location=current_table.metadata_location,
- )
+ try:
+ current_glue_table =
self._get_glue_table(database_name=database_name, table_name=table_name)
+ # Update the table
+ glue_table_version_id = current_glue_table.get("VersionId")
+ if not glue_table_version_id:
+ raise CommitFailedException(
+ f"Cannot commit {database_name}.{table_name} because Glue
table version id is missing"
+ )
+ current_table =
self._convert_glue_to_iceberg(glue_table=current_glue_table)
+ base_metadata = current_table.metadata
+
+ # Validate the update requirements
+ for requirement in table_request.requirements:
+ requirement.validate(base_metadata)
+
+ updated_metadata =
update_table_metadata(base_metadata=base_metadata,
updates=table_request.updates)
+ if updated_metadata == base_metadata:
+ # no changes, do nothing
+ return CommitTableResponse(metadata=base_metadata,
metadata_location=current_table.metadata_location)
+
+ # write new metadata
+ new_metadata_version =
self._parse_metadata_version(current_table.metadata_location) + 1
+ new_metadata_location =
self._get_metadata_location(current_table.metadata.location,
new_metadata_version)
+ self._write_metadata(updated_metadata, current_table.io,
new_metadata_location)
+
+ update_table_input = _construct_table_input(
+ table_name=table_name,
+ metadata_location=new_metadata_location,
+ properties=current_table.properties,
+ metadata=updated_metadata,
+ glue_table=current_glue_table,
+ prev_metadata_location=current_table.metadata_location,
+ )
- # Pass `version_id` to implement optimistic locking: it ensures
updates are rejected if concurrent
- # modifications occur. See more details at
https://iceberg.apache.org/docs/latest/aws/#optimistic-locking
- self._update_glue_table(
- database_name=database_name,
- table_name=table_name,
- table_input=update_table_input,
- version_id=glue_table_version_id,
- )
+ # Pass `version_id` to implement optimistic locking: it ensures
updates are rejected if concurrent
+ # modifications occur. See more details at
https://iceberg.apache.org/docs/latest/aws/#optimistic-locking
+ self._update_glue_table(
+ database_name=database_name,
+ table_name=table_name,
+ table_input=update_table_input,
+ version_id=glue_table_version_id,
+ )
+
+ return CommitTableResponse(metadata=updated_metadata,
metadata_location=new_metadata_location)
+ except NoSuchTableError:
+ # Create the table
+ updated_metadata = update_table_metadata(
+ base_metadata=self._empty_table_metadata(),
updates=table_request.updates, enforce_validation=True
+ )
+ new_metadata_version = 0
+ new_metadata_location =
self._get_metadata_location(updated_metadata.location, new_metadata_version)
+ self._write_metadata(
+ updated_metadata,
self._load_file_io(updated_metadata.properties, new_metadata_location),
new_metadata_location
+ )
+
+ create_table_input = _construct_table_input(
+ table_name=table_name,
+ metadata_location=new_metadata_location,
+ properties=updated_metadata.properties,
+ metadata=updated_metadata,
+ )
+
+ self._create_glue_table(database_name=database_name,
table_name=table_name, table_input=create_table_input)
- return CommitTableResponse(metadata=updated_metadata,
metadata_location=new_metadata_location)
+ return CommitTableResponse(metadata=updated_metadata,
metadata_location=new_metadata_location)
def load_table(self, identifier: Union[str, Identifier]) -> Table:
"""Load the table's metadata and returns the table instance.
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py
index 18bbcfe0..359bdef5 100644
--- a/pyiceberg/catalog/hive.py
+++ b/pyiceberg/catalog/hive.py
@@ -58,7 +58,7 @@ from pyiceberg.catalog import (
LOCATION,
METADATA_LOCATION,
TABLE_TYPE,
- Catalog,
+ MetastoreCatalog,
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
@@ -230,7 +230,7 @@ class SchemaToHiveConverter(SchemaVisitor[str]):
return HIVE_PRIMITIVE_TYPES[type(primitive)]
-class HiveCatalog(Catalog):
+class HiveCatalog(MetastoreCatalog):
_client: _HiveClient
def __init__(self, name: str, **properties: str):
diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py
index e294390e..1dfeb952 100644
--- a/pyiceberg/catalog/noop.py
+++ b/pyiceberg/catalog/noop.py
@@ -28,6 +28,7 @@ from pyiceberg.schema import Schema
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
+ CreateTableTransaction,
Table,
)
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
@@ -49,9 +50,23 @@ class NoopCatalog(Catalog):
) -> Table:
raise NotImplementedError
+ def create_table_transaction(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> CreateTableTransaction:
+ raise NotImplementedError
+
def load_table(self, identifier: Union[str, Identifier]) -> Table:
raise NotImplementedError
+ def table_exists(self, identifier: Union[str, Identifier]) -> bool:
+ raise NotImplementedError
+
def register_table(self, identifier: Union[str, Identifier],
metadata_location: str) -> Table:
"""Register a new table using existing metadata.
@@ -70,6 +85,9 @@ class NoopCatalog(Catalog):
def drop_table(self, identifier: Union[str, Identifier]) -> None:
raise NotImplementedError
+ def purge_table(self, identifier: Union[str, Identifier]) -> None:
+ raise NotImplementedError
+
def rename_table(self, from_identifier: Union[str, Identifier],
to_identifier: Union[str, Identifier]) -> Table:
raise NotImplementedError
diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index 81a9b09f..53e3f6a1 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -61,6 +61,8 @@ from pyiceberg.schema import Schema, assign_fresh_schema_ids
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
+ CreateTableTransaction,
+ StagedTable,
Table,
TableIdentifier,
)
@@ -135,7 +137,7 @@ _RETRY_ARGS = {
class TableResponse(IcebergBaseModel):
- metadata_location: str = Field(alias="metadata-location")
+ metadata_location: Optional[str] = Field(alias="metadata-location")
metadata: TableMetadata
config: Properties = Field(default_factory=dict)
@@ -460,7 +462,18 @@ class RestCatalog(Catalog):
def _response_to_table(self, identifier_tuple: Tuple[str, ...],
table_response: TableResponse) -> Table:
return Table(
identifier=(self.name,) + identifier_tuple if self.name else
identifier_tuple,
- metadata_location=table_response.metadata_location,
+ metadata_location=table_response.metadata_location, # type: ignore
+ metadata=table_response.metadata,
+ io=self._load_file_io(
+ {**table_response.metadata.properties,
**table_response.config}, table_response.metadata_location
+ ),
+ catalog=self,
+ )
+
+ def _response_to_staged_table(self, identifier_tuple: Tuple[str, ...],
table_response: TableResponse) -> StagedTable:
+ return StagedTable(
+ identifier=(self.name,) + identifier_tuple if self.name else
identifier_tuple,
+ metadata_location=table_response.metadata_location, # type: ignore
metadata=table_response.metadata,
io=self._load_file_io(
{**table_response.metadata.properties,
**table_response.config}, table_response.metadata_location
@@ -490,8 +503,7 @@ class RestCatalog(Catalog):
def _extract_headers_from_properties(self) -> Dict[str, str]:
return {key[len(HEADER_PREFIX) :]: value for key, value in
self.properties.items() if key.startswith(HEADER_PREFIX)}
- @retry(**_RETRY_ARGS)
- def create_table(
+ def _create_table(
self,
identifier: Union[str, Identifier],
schema: Union[Schema, "pa.Schema"],
@@ -499,7 +511,8 @@ class RestCatalog(Catalog):
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
- ) -> Table:
+ stage_create: bool = False,
+ ) -> TableResponse:
iceberg_schema = self._convert_schema_if_needed(schema)
fresh_schema = assign_fresh_schema_ids(iceberg_schema)
fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec,
iceberg_schema, fresh_schema)
@@ -512,6 +525,7 @@ class RestCatalog(Catalog):
table_schema=fresh_schema,
partition_spec=fresh_partition_spec,
write_order=fresh_sort_order,
+ stage_create=stage_create,
properties=properties,
)
serialized_json = request.model_dump_json().encode(UTF8)
@@ -524,9 +538,51 @@ class RestCatalog(Catalog):
except HTTPError as exc:
self._handle_non_200_response(exc, {409: TableAlreadyExistsError})
- table_response = TableResponse(**response.json())
+ return TableResponse(**response.json())
+
+ @retry(**_RETRY_ARGS)
+ def create_table(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> Table:
+ table_response = self._create_table(
+ identifier=identifier,
+ schema=schema,
+ location=location,
+ partition_spec=partition_spec,
+ sort_order=sort_order,
+ properties=properties,
+ stage_create=False,
+ )
return self._response_to_table(self.identifier_to_tuple(identifier),
table_response)
+ @retry(**_RETRY_ARGS)
+ def create_table_transaction(
+ self,
+ identifier: Union[str, Identifier],
+ schema: Union[Schema, "pa.Schema"],
+ location: Optional[str] = None,
+ partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
+ sort_order: SortOrder = UNSORTED_SORT_ORDER,
+ properties: Properties = EMPTY_DICT,
+ ) -> CreateTableTransaction:
+ table_response = self._create_table(
+ identifier=identifier,
+ schema=schema,
+ location=location,
+ partition_spec=partition_spec,
+ sort_order=sort_order,
+ properties=properties,
+ stage_create=True,
+ )
+ staged_table =
self._response_to_staged_table(self.identifier_to_tuple(identifier),
table_response)
+ return CreateTableTransaction(staged_table)
+
@retry(**_RETRY_ARGS)
def register_table(self, identifier: Union[str, Identifier],
metadata_location: str) -> Table:
"""Register a new table using existing metadata.
@@ -720,6 +776,14 @@ class RestCatalog(Catalog):
@retry(**_RETRY_ARGS)
def table_exists(self, identifier: Union[str, Identifier]) -> bool:
+ """Check if a table exists.
+
+ Args:
+ identifier (str | Identifier): Table identifier.
+
+ Returns:
+ bool: True if the table exists, False otherwise.
+ """
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
response = self._session.head(
self.url(Endpoints.load_table, prefixed=True,
**self._split_identifier_for_path(identifier_tuple))
diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py
index d44d4996..978109b2 100644
--- a/pyiceberg/catalog/sql.py
+++ b/pyiceberg/catalog/sql.py
@@ -43,7 +43,7 @@ from sqlalchemy.orm import (
from pyiceberg.catalog import (
METADATA_LOCATION,
- Catalog,
+ MetastoreCatalog,
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
@@ -93,7 +93,7 @@ class IcebergNamespaceProperties(SqlCatalogBaseTable):
property_value: Mapped[str] = mapped_column(String(1000), nullable=False)
-class SqlCatalog(Catalog):
+class SqlCatalog(MetastoreCatalog):
def __init__(self, name: str, **properties: str):
super().__init__(name, **properties)
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 5f67c05c..4e968eb6 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -76,6 +76,7 @@ from pyiceberg.manifest import (
from pyiceberg.partitioning import (
INITIAL_PARTITION_SPEC_ID,
PARTITION_FIELD_ID_START,
+ UNPARTITIONED_PARTITION_SPEC,
PartitionField,
PartitionSpec,
_PartitionNameGenerator,
@@ -111,7 +112,7 @@ from pyiceberg.table.snapshots import (
Summary,
update_snapshot_summaries,
)
-from pyiceberg.table.sorting import SortOrder
+from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform,
VoidTransform
from pyiceberg.typedef import (
EMPTY_DICT,
@@ -144,7 +145,6 @@ if TYPE_CHECKING:
from pyiceberg.catalog import Catalog
-
ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
@@ -402,6 +402,59 @@ class Transaction:
return self._table
+class CreateTableTransaction(Transaction):
+ def _initial_changes(self, table_metadata: TableMetadata) -> None:
+ """Set the initial changes that can reconstruct the initial table
metadata when creating the CreateTableTransaction."""
+ self._updates += (
+ AssignUUIDUpdate(uuid=table_metadata.table_uuid),
+
UpgradeFormatVersionUpdate(format_version=table_metadata.format_version),
+ )
+
+ schema: Schema = table_metadata.schema()
+ self._updates += (
+ AddSchemaUpdate(schema_=schema,
last_column_id=schema.highest_field_id, initial_change=True),
+ SetCurrentSchemaUpdate(schema_id=-1),
+ )
+
+ spec: PartitionSpec = table_metadata.spec()
+ if spec.is_unpartitioned():
+ self._updates +=
(AddPartitionSpecUpdate(spec=UNPARTITIONED_PARTITION_SPEC,
initial_change=True),)
+ else:
+ self._updates += (AddPartitionSpecUpdate(spec=spec,
initial_change=True),)
+ self._updates += (SetDefaultSpecUpdate(spec_id=-1),)
+
+ sort_order: Optional[SortOrder] =
table_metadata.sort_order_by_id(table_metadata.default_sort_order_id)
+ if sort_order is None or sort_order.is_unsorted:
+ self._updates +=
(AddSortOrderUpdate(sort_order=UNSORTED_SORT_ORDER, initial_change=True),)
+ else:
+ self._updates += (AddSortOrderUpdate(sort_order=sort_order,
initial_change=True),)
+ self._updates += (SetDefaultSortOrderUpdate(sort_order_id=-1),)
+
+ self._updates += (
+ SetLocationUpdate(location=table_metadata.location),
+ SetPropertiesUpdate(updates=table_metadata.properties),
+ )
+
+ def __init__(self, table: StagedTable):
+ super().__init__(table, autocommit=False)
+ self._initial_changes(table.metadata)
+
+ def commit_transaction(self) -> Table:
+ """Commit the changes to the catalog.
+
+ In the case of a CreateTableTransaction, the only requirement is
AssertCreate.
+ Returns:
+ The table with the updates applied.
+ """
+ self._requirements = (AssertCreate(),)
+ return super().commit_transaction()
+
+
+class AssignUUIDUpdate(IcebergBaseModel):
+ action: Literal['assign-uuid'] = Field(default="assign-uuid")
+ uuid: uuid.UUID
+
+
class UpgradeFormatVersionUpdate(IcebergBaseModel):
action: Literal['upgrade-format-version'] =
Field(default="upgrade-format-version")
format_version: int = Field(alias="format-version")
@@ -413,6 +466,8 @@ class AddSchemaUpdate(IcebergBaseModel):
# This field is required: https://github.com/apache/iceberg/pull/7445
last_column_id: int = Field(alias="last-column-id")
+ initial_change: bool = Field(default=False, exclude=True)
+
class SetCurrentSchemaUpdate(IcebergBaseModel):
action: Literal['set-current-schema'] = Field(default="set-current-schema")
@@ -425,6 +480,8 @@ class AddPartitionSpecUpdate(IcebergBaseModel):
action: Literal['add-spec'] = Field(default="add-spec")
spec: PartitionSpec
+ initial_change: bool = Field(default=False, exclude=True)
+
class SetDefaultSpecUpdate(IcebergBaseModel):
action: Literal['set-default-spec'] = Field(default="set-default-spec")
@@ -437,6 +494,8 @@ class AddSortOrderUpdate(IcebergBaseModel):
action: Literal['add-sort-order'] = Field(default="add-sort-order")
sort_order: SortOrder = Field(alias="sort-order")
+ initial_change: bool = Field(default=False, exclude=True)
+
class SetDefaultSortOrderUpdate(IcebergBaseModel):
action: Literal['set-default-sort-order'] =
Field(default="set-default-sort-order")
@@ -491,6 +550,7 @@ class RemovePropertiesUpdate(IcebergBaseModel):
TableUpdate = Annotated[
Union[
+ AssignUUIDUpdate,
UpgradeFormatVersionUpdate,
AddSchemaUpdate,
SetCurrentSchemaUpdate,
@@ -527,6 +587,9 @@ class _TableMetadataUpdateContext:
def is_added_schema(self, schema_id: int) -> bool:
return any(update.schema_.schema_id == schema_id for update in
self._updates if isinstance(update, AddSchemaUpdate))
+ def is_added_partition_spec(self, spec_id: int) -> bool:
+ return any(update.spec.spec_id == spec_id for update in self._updates
if isinstance(update, AddPartitionSpecUpdate))
+
def is_added_sort_order(self, sort_order_id: int) -> bool:
return any(
update.sort_order.order_id == sort_order_id for update in
self._updates if isinstance(update, AddSortOrderUpdate)
@@ -549,8 +612,27 @@ def _apply_table_update(update: TableUpdate,
base_metadata: TableMetadata, conte
raise NotImplementedError(f"Unsupported table update: {update}")
+@_apply_table_update.register(AssignUUIDUpdate)
+def _(update: AssignUUIDUpdate, base_metadata: TableMetadata, context:
_TableMetadataUpdateContext) -> TableMetadata:
+ if update.uuid == base_metadata.table_uuid:
+ return base_metadata
+
+ context.add_update(update)
+ return base_metadata.model_copy(update={"table_uuid": update.uuid})
+
+
+@_apply_table_update.register(SetLocationUpdate)
+def _(update: SetLocationUpdate, base_metadata: TableMetadata, context:
_TableMetadataUpdateContext) -> TableMetadata:
+ context.add_update(update)
+ return base_metadata.model_copy(update={"location": update.location})
+
+
@_apply_table_update.register(UpgradeFormatVersionUpdate)
-def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata,
context: _TableMetadataUpdateContext) -> TableMetadata:
+def _(
+ update: UpgradeFormatVersionUpdate,
+ base_metadata: TableMetadata,
+ context: _TableMetadataUpdateContext,
+) -> TableMetadata:
if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION:
raise ValueError(f"Unsupported table format version:
{update.format_version}")
elif update.format_version < base_metadata.format_version:
@@ -595,13 +677,13 @@ def _(update: AddSchemaUpdate, base_metadata:
TableMetadata, context: _TableMeta
if update.last_column_id < base_metadata.last_column_id:
raise ValueError(f"Invalid last column id {update.last_column_id},
must be >= {base_metadata.last_column_id}")
+ metadata_updates: Dict[str, Any] = {
+ "last_column_id": update.last_column_id,
+ "schemas": [update.schema_] if update.initial_change else
base_metadata.schemas + [update.schema_],
+ }
+
context.add_update(update)
- return base_metadata.model_copy(
- update={
- "last_column_id": update.last_column_id,
- "schemas": base_metadata.schemas + [update.schema_],
- }
- )
+ return base_metadata.model_copy(update=metadata_updates)
@_apply_table_update.register(SetCurrentSchemaUpdate)
@@ -627,18 +709,19 @@ def _(update: SetCurrentSchemaUpdate, base_metadata:
TableMetadata, context: _Ta
@_apply_table_update.register(AddPartitionSpecUpdate)
def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context:
_TableMetadataUpdateContext) -> TableMetadata:
for spec in base_metadata.partition_specs:
- if spec.spec_id == update.spec.spec_id:
+ if spec.spec_id == update.spec.spec_id and not update.initial_change:
raise ValueError(f"Partition spec with id {spec.spec_id} already
exists: {spec}")
+
+ metadata_updates: Dict[str, Any] = {
+ "partition_specs": [update.spec] if update.initial_change else
base_metadata.partition_specs + [update.spec],
+ "last_partition_id": max(
+ max([field.field_id for field in update.spec.fields], default=0),
+ base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1,
+ ),
+ }
+
context.add_update(update)
- return base_metadata.model_copy(
- update={
- "partition_specs": base_metadata.partition_specs + [update.spec],
- "last_partition_id": max(
- max(field.field_id for field in update.spec.fields),
- base_metadata.last_partition_id or PARTITION_FIELD_ID_START -
1,
- ),
- }
- )
+ return base_metadata.model_copy(update=metadata_updates)
@_apply_table_update.register(SetDefaultSpecUpdate)
@@ -646,6 +729,8 @@ def _(update: SetDefaultSpecUpdate, base_metadata:
TableMetadata, context: _Tabl
new_spec_id = update.spec_id
if new_spec_id == -1:
new_spec_id = max(spec.spec_id for spec in
base_metadata.partition_specs)
+ if not context.is_added_partition_spec(new_spec_id):
+ raise ValueError("Cannot set current partition spec to last added
one when no partition spec has been added")
if new_spec_id == base_metadata.default_spec_id:
return base_metadata
found_spec_id = False
@@ -736,13 +821,17 @@ def _(update: AddSortOrderUpdate, base_metadata:
TableMetadata, context: _TableM
context.add_update(update)
return base_metadata.model_copy(
update={
- "sort_orders": base_metadata.sort_orders + [update.sort_order],
+ "sort_orders": [update.sort_order] if update.initial_change else
base_metadata.sort_orders + [update.sort_order],
}
)
@_apply_table_update.register(SetDefaultSortOrderUpdate)
-def _(update: SetDefaultSortOrderUpdate, base_metadata: TableMetadata,
context: _TableMetadataUpdateContext) -> TableMetadata:
+def _(
+ update: SetDefaultSortOrderUpdate,
+ base_metadata: TableMetadata,
+ context: _TableMetadataUpdateContext,
+) -> TableMetadata:
new_sort_order_id = update.sort_order_id
if new_sort_order_id == -1:
# The last added sort order should be in base_metadata.sort_orders at
this point
@@ -761,12 +850,15 @@ def _(update: SetDefaultSortOrderUpdate, base_metadata:
TableMetadata, context:
return base_metadata.model_copy(update={"default_sort_order_id":
new_sort_order_id})
-def update_table_metadata(base_metadata: TableMetadata, updates:
Tuple[TableUpdate, ...]) -> TableMetadata:
+def update_table_metadata(
+ base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...],
enforce_validation: bool = False
+) -> TableMetadata:
"""Update the table metadata with the given updates in one transaction.
Args:
base_metadata: The base metadata to be updated.
updates: The updates in one transaction.
+ enforce_validation: Whether to trigger validation after applying the
updates.
Returns:
The metadata with the updates applied.
@@ -777,7 +869,10 @@ def update_table_metadata(base_metadata: TableMetadata,
updates: Tuple[TableUpda
for update in updates:
new_metadata = _apply_table_update(update, new_metadata, context)
- return new_metadata.model_copy(deep=True)
+ if enforce_validation:
+ return TableMetadataUtil.parse_obj(new_metadata.model_dump())
+ else:
+ return new_metadata.model_copy(deep=True)
class ValidatableTableRequirement(IcebergBaseModel):
@@ -1287,6 +1382,25 @@ class StaticTable(Table):
)
+class StagedTable(Table):
+ def refresh(self) -> Table:
+ raise ValueError("Cannot refresh a staged table")
+
+ def scan(
+ self,
+ row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
+ selected_fields: Tuple[str, ...] = ("*",),
+ case_sensitive: bool = True,
+ snapshot_id: Optional[int] = None,
+ options: Properties = EMPTY_DICT,
+ limit: Optional[int] = None,
+ ) -> DataScan:
+ raise ValueError("Cannot scan a staged table")
+
+ def to_daft(self) -> daft.DataFrame:
+ raise ValueError("Cannot convert a staged table to a Daft DataFrame")
+
+
def _parse_row_filter(expr: Union[str, BooleanExpression]) ->
BooleanExpression:
"""Accept an expression in the form of a BooleanExpression or a string.
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 3e1acf95..2e20c509 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -412,7 +412,7 @@ class TableMetadataV1(TableMetadataCommonFields,
IcebergBaseModel):
"""The table’s current schema. (Deprecated: use schemas and
current-schema-id instead)."""
- partition_spec: List[Dict[str, Any]] = Field(alias="partition-spec")
+ partition_spec: List[Dict[str, Any]] = Field(alias="partition-spec",
default_factory=list)
"""The table’s current partition spec, stored as only fields.
Note that this is used by writers to partition data, but is
not used when reading because reads use the specs stored in
diff --git a/tests/catalog/integration_test_glue.py
b/tests/catalog/integration_test_glue.py
index a685b7da..5cd60225 100644
--- a/tests/catalog/integration_test_glue.py
+++ b/tests/catalog/integration_test_glue.py
@@ -24,7 +24,7 @@ import pyarrow as pa
import pytest
from botocore.exceptions import ClientError
-from pyiceberg.catalog import Catalog
+from pyiceberg.catalog import Catalog, MetastoreCatalog
from pyiceberg.catalog.glue import GlueCatalog
from pyiceberg.exceptions import (
NamespaceAlreadyExistsError,
@@ -35,6 +35,7 @@ from pyiceberg.exceptions import (
)
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
+from pyiceberg.table import _dataframe_to_data_files
from pyiceberg.types import IntegerType
from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -120,7 +121,7 @@ def test_create_table(
assert table.identifier == (CATALOG_NAME,) + identifier
metadata_location = table.metadata_location.split(get_bucket_name())[1][1:]
s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
table.append(
pa.Table.from_pylist(
@@ -184,7 +185,7 @@ def test_create_table_with_default_location(
assert table.identifier == (CATALOG_NAME,) + identifier
metadata_location = table.metadata_location.split(get_bucket_name())[1][1:]
s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
def test_create_table_with_invalid_database(test_catalog: Catalog,
table_schema_nested: Schema, table_name: str) -> None:
@@ -217,7 +218,7 @@ def test_load_table(test_catalog: Catalog,
table_schema_nested: Schema, table_na
assert table.identifier == loaded_table.identifier
assert table.metadata_location == loaded_table.metadata_location
assert table.metadata == loaded_table.metadata
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
def test_list_tables(test_catalog: Catalog, table_schema_nested: Schema,
database_name: str, table_list: List[str]) -> None:
@@ -239,7 +240,7 @@ def test_rename_table(
new_table_name = f"rename-{table_name}"
identifier = (database_name, table_name)
table = test_catalog.create_table(identifier, table_schema_nested)
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
assert table.identifier == (CATALOG_NAME,) + identifier
new_identifier = (new_database_name, new_table_name)
test_catalog.rename_table(identifier, new_identifier)
@@ -385,7 +386,7 @@ def test_commit_table_update_schema(
table = test_catalog.create_table(identifier, table_schema_nested)
original_table_metadata = table.metadata
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
assert original_table_metadata.current_schema_id == 0
assert athena.get_query_results(f'SELECT * FROM
"{database_name}"."{table_name}"') == [
@@ -410,7 +411,7 @@ def test_commit_table_update_schema(
updated_table_metadata = table.metadata
- assert test_catalog._parse_metadata_version(table.metadata_location) == 1
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 1
assert updated_table_metadata.current_schema_id == 1
assert len(updated_table_metadata.schemas) == 2
new_schema = next(schema for schema in updated_table_metadata.schemas if
schema.schema_id == 1)
@@ -466,7 +467,7 @@ def test_commit_table_properties(test_catalog: Catalog,
table_schema_nested: Sch
test_catalog.create_namespace(namespace=database_name)
table = test_catalog.create_table(identifier=identifier,
schema=table_schema_nested, properties={"test_a": "test_a"})
- assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
transaction = table.transaction()
transaction.set_properties(test_a="test_aa", test_b="test_b",
test_c="test_c")
@@ -474,5 +475,78 @@ def test_commit_table_properties(test_catalog: Catalog,
table_schema_nested: Sch
transaction.commit_transaction()
updated_table_metadata = table.metadata
- assert test_catalog._parse_metadata_version(table.metadata_location) == 1
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 1
assert updated_table_metadata.properties == {"test_a": "test_aa",
"test_c": "test_c"}
+
+
[email protected]("format_version", [1, 2])
+def test_create_table_transaction(
+ test_catalog: Catalog,
+ s3: boto3.client,
+ table_schema_nested: Schema,
+ table_name: str,
+ database_name: str,
+ athena: AthenaQueryHelper,
+ format_version: int,
+) -> None:
+ identifier = (database_name, table_name)
+ test_catalog.create_namespace(database_name)
+
+ with test_catalog.create_table_transaction(
+ identifier,
+ table_schema_nested,
+ get_s3_path(get_bucket_name(), database_name, table_name),
+ properties={"format-version": format_version},
+ ) as txn:
+ df = pa.Table.from_pylist(
+ [
+ {
+ "foo": "foo_val",
+ "bar": 1,
+ "baz": False,
+ "qux": ["x", "y"],
+ "quux": {"key": {"subkey": 2}},
+ "location": [{"latitude": 1.1}],
+ "person": {"name": "some_name", "age": 23},
+ }
+ ],
+ schema=schema_to_pyarrow(txn.table_metadata.schema()),
+ )
+
+ with txn.update_snapshot().fast_append() as update_snapshot:
+ data_files = _dataframe_to_data_files(
+ table_metadata=txn.table_metadata,
write_uuid=update_snapshot.commit_uuid, df=df, io=txn._table.io
+ )
+ for data_file in data_files:
+ update_snapshot.append_data_file(data_file)
+
+ table = test_catalog.load_table(identifier)
+ assert table.identifier == (CATALOG_NAME,) + identifier
+ metadata_location = table.metadata_location.split(get_bucket_name())[1][1:]
+ s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
+ assert MetastoreCatalog._parse_metadata_version(table.metadata_location)
== 0
+
+ assert athena.get_query_results(f'SELECT * FROM
"{database_name}"."{table_name}"') == [
+ {
+ "Data": [
+ {"VarCharValue": "foo"},
+ {"VarCharValue": "bar"},
+ {"VarCharValue": "baz"},
+ {"VarCharValue": "qux"},
+ {"VarCharValue": "quux"},
+ {"VarCharValue": "location"},
+ {"VarCharValue": "person"},
+ ]
+ },
+ {
+ "Data": [
+ {"VarCharValue": "foo_val"},
+ {"VarCharValue": "1"},
+ {"VarCharValue": "false"},
+ {"VarCharValue": "[x, y]"},
+ {"VarCharValue": "{key={subkey=2}}"},
+ {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+ {"VarCharValue": "{name=some_name, age=23}"},
+ ]
+ },
+ ]
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 5f78eb3b..8ea04e3f 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -34,6 +34,7 @@ from pytest_lazyfixture import lazy_fixture
from pyiceberg.catalog import (
Catalog,
+ MetastoreCatalog,
PropertiesUpdateSummary,
)
from pyiceberg.exceptions import (
@@ -65,7 +66,7 @@ from pyiceberg.types import IntegerType, LongType, NestedField
DEFAULT_WAREHOUSE_LOCATION = "file:///tmp/warehouse"
-class InMemoryCatalog(Catalog):
+class InMemoryCatalog(MetastoreCatalog):
"""
An in-memory catalog implementation that uses in-memory data-structures to
store the namespaces and tables.
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index d4ed085c..8aa49186 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -33,7 +33,9 @@ from pyiceberg.exceptions import (
TableAlreadyExistsError,
)
from pyiceberg.io.pyarrow import schema_to_pyarrow
+from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
+from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType
from tests.conftest import BUCKET_NAME, TABLE_METADATA_LOCATION_REGEX
@@ -758,3 +760,49 @@ def test_commit_overwrite_table_snapshot_properties(
assert summary is not None
assert summary["snapshot_prop_a"] is None
assert summary["snapshot_prop_b"] == "test_prop_b"
+
+
+@mock_aws
[email protected]("format_version", [1, 2])
+def test_create_table_transaction(
+ _glue: boto3.client,
+ _bucket_initialize: None,
+ moto_endpoint_url: str,
+ table_schema_nested: Schema,
+ database_name: str,
+ table_name: str,
+ format_version: int,
+) -> None:
+ catalog_name = "glue"
+ identifier = (database_name, table_name)
+ test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint":
moto_endpoint_url, "warehouse": f"s3://{BUCKET_NAME}"})
+ test_catalog.create_namespace(namespace=database_name)
+
+ with test_catalog.create_table_transaction(
+ identifier,
+ table_schema_nested,
+ partition_spec=PartitionSpec(PartitionField(source_id=1,
field_id=1000, transform=IdentityTransform(), name="foo")),
+ properties={"format-version": format_version},
+ ) as txn:
+ with txn.update_schema() as update_schema:
+ update_schema.add_column(path="b", field_type=IntegerType())
+
+ with txn.update_spec() as update_spec:
+ update_spec.add_identity("bar")
+
+ txn.set_properties(test_a="test_aa", test_b="test_b", test_c="test_c")
+
+ table = test_catalog.load_table(identifier)
+
+ assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+ assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ assert table.format_version == format_version
+ assert table.schema().find_field("b").field_type == IntegerType()
+ assert table.properties == {"test_a": "test_aa", "test_b": "test_b",
"test_c": "test_c"}
+ assert table.spec().last_assigned_field_id == 1001
+ assert table.spec().fields_by_source_id(1)[0].name == "foo"
+ assert table.spec().fields_by_source_id(1)[0].field_id == 1000
+ assert table.spec().fields_by_source_id(1)[0].transform ==
IdentityTransform()
+ assert table.spec().fields_by_source_id(2)[0].name == "bar"
+ assert table.spec().fields_by_source_id(2)[0].field_id == 1001
+ assert table.spec().fields_by_source_id(2)[0].transform ==
IdentityTransform()
diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py
index 0186e662..e8ad6b08 100644
--- a/tests/integration/test_writes.py
+++ b/tests/integration/test_writes.py
@@ -680,6 +680,60 @@ def test_write_and_evolve(session_catalog: Catalog,
format_version: int) -> None
snapshot_update.append_data_file(data_file)
[email protected]
[email protected]("format_version", [2])
+def test_create_table_transaction(session_catalog: Catalog, format_version:
int) -> None:
+ if format_version == 1:
+ pytest.skip(
+ "There is a bug in the REST catalog (maybe server side) that
prevents create and commit a staged version 1 table"
+ )
+
+ identifier = f"default.arrow_create_table_transaction{format_version}"
+
+ try:
+ session_catalog.drop_table(identifier=identifier)
+ except NoSuchTableError:
+ pass
+
+ pa_table = pa.Table.from_pydict(
+ {
+ 'foo': ['a', None, 'z'],
+ },
+ schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
+ )
+
+ pa_table_with_column = pa.Table.from_pydict(
+ {
+ 'foo': ['a', None, 'z'],
+ 'bar': [19, None, 25],
+ },
+ schema=pa.schema([
+ pa.field("foo", pa.string(), nullable=True),
+ pa.field("bar", pa.int32(), nullable=True),
+ ]),
+ )
+
+ with session_catalog.create_table_transaction(
+ identifier=identifier, schema=pa_table.schema,
properties={"format-version": str(format_version)}
+ ) as txn:
+ with txn.update_snapshot().fast_append() as snapshot_update:
+ for data_file in
_dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table,
io=txn._table.io):
+ snapshot_update.append_data_file(data_file)
+
+ with txn.update_schema() as schema_txn:
+ schema_txn.union_by_name(pa_table_with_column.schema)
+
+ with txn.update_snapshot().fast_append() as snapshot_update:
+ for data_file in _dataframe_to_data_files(
+ table_metadata=txn.table_metadata, df=pa_table_with_column,
io=txn._table.io
+ ):
+ snapshot_update.append_data_file(data_file)
+
+ tbl = session_catalog.load_table(identifier=identifier)
+ assert tbl.format_version == format_version
+ assert len(tbl.scan().to_arrow()) == 6
+
+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_table_properties_int_value(