This is an automated email from the ASF dual-hosted git repository.
honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git
The following commit(s) were added to refs/heads/main by this push:
new 4391919 Set Glue Table Information when creating/updating tables (#288)
4391919 is described below
commit 43919199ea961db25aa855b6bbe5b370d252fe21
Author: Michael Marino <[email protected]>
AuthorDate: Thu Jan 25 02:44:34 2024 +0100
Set Glue Table Information when creating/updating tables (#288)
* Set Glue Table Information when creating/updating tables
* Add integration tests for glue/Athena
---
pyiceberg/catalog/glue.py | 113 ++++++++++++++++++++++-
tests/catalog/integration_test_glue.py | 162 ++++++++++++++++++++++++++++++++-
tests/catalog/test_glue.py | 45 ++++++++-
3 files changed, 312 insertions(+), 8 deletions(-)
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index 6cf9462..bccbfa4 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -18,6 +18,7 @@
from typing import (
Any,
+ Dict,
List,
Optional,
Set,
@@ -28,6 +29,7 @@ from typing import (
import boto3
from mypy_boto3_glue.client import GlueClient
from mypy_boto3_glue.type_defs import (
+ ColumnTypeDef,
DatabaseInputTypeDef,
DatabaseTypeDef,
StorageDescriptorTypeDef,
@@ -59,12 +61,32 @@ from pyiceberg.exceptions import (
)
from pyiceberg.io import load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
-from pyiceberg.schema import Schema
+from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
-from pyiceberg.table.metadata import new_table_metadata
+from pyiceberg.table.metadata import TableMetadata, new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT
+from pyiceberg.types import (
+ BinaryType,
+ BooleanType,
+ DateType,
+ DecimalType,
+ DoubleType,
+ FixedType,
+ FloatType,
+ IntegerType,
+ ListType,
+ LongType,
+ MapType,
+ NestedField,
+ PrimitiveType,
+ StringType,
+ StructType,
+ TimestampType,
+ TimeType,
+ UUIDType,
+)
# If Glue should skip archiving an old table version when creating a new version in a commit. By
# default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
@@ -73,6 +95,10 @@ from pyiceberg.typedef import EMPTY_DICT
GLUE_SKIP_ARCHIVE = "glue.skip-archive"
GLUE_SKIP_ARCHIVE_DEFAULT = True
+ICEBERG_FIELD_ID = "iceberg.field.id"
+ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional"
+ICEBERG_FIELD_CURRENT = "iceberg.field.current"
+
def _construct_parameters(
metadata_location: str, glue_table: Optional[TableTypeDef] = None, prev_metadata_location: Optional[str] = None
@@ -84,10 +110,86 @@ def _construct_parameters(
return new_parameters
+GLUE_PRIMITIVE_TYPES = {
+ BooleanType: "boolean",
+ IntegerType: "int",
+ LongType: "bigint",
+ FloatType: "float",
+ DoubleType: "double",
+ DateType: "date",
+ TimeType: "string",
+ StringType: "string",
+ UUIDType: "string",
+ TimestampType: "timestamp",
+ FixedType: "binary",
+ BinaryType: "binary",
+}
+
+
+class _IcebergSchemaToGlueType(SchemaVisitor[str]):
+ def schema(self, schema: Schema, struct_result: str) -> str:
+ return struct_result
+
+ def struct(self, struct: StructType, field_results: List[str]) -> str:
+ return f"struct<{','.join(field_results)}>"
+
+ def field(self, field: NestedField, field_result: str) -> str:
+ return f"{field.name}:{field_result}"
+
+ def list(self, list_type: ListType, element_result: str) -> str:
+ return f"array<{element_result}>"
+
+ def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
+ return f"map<{key_result},{value_result}>"
+
+ def primitive(self, primitive: PrimitiveType) -> str:
+ if isinstance(primitive, DecimalType):
+ return f"decimal({primitive.precision},{primitive.scale})"
+ if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
+ raise ValueError(f"Unknown primitive type: {primitive}")
+ return GLUE_PRIMITIVE_TYPES[primitive_type]
+
+
+def _to_columns(metadata: TableMetadata) -> List[ColumnTypeDef]:
+ results: Dict[str, ColumnTypeDef] = {}
+
+ def _append_to_results(field: NestedField, is_current: bool) -> None:
+ if field.name in results:
+ return
+
+ results[field.name] = cast(
+ ColumnTypeDef,
+ {
+ "Name": field.name,
+ "Type": visit(field.field_type, _IcebergSchemaToGlueType()),
+ "Parameters": {
+ ICEBERG_FIELD_ID: str(field.field_id),
+ ICEBERG_FIELD_OPTIONAL: str(field.optional).lower(),
+ ICEBERG_FIELD_CURRENT: str(is_current).lower(),
+ },
+ },
+ )
+ if field.doc:
+ results[field.name]["Comment"] = field.doc
+
+ if current_schema := metadata.schema_by_id(metadata.current_schema_id):
+ for field in current_schema.columns:
+ _append_to_results(field, True)
+
+ for schema in metadata.schemas:
+ if schema.schema_id == metadata.current_schema_id:
+ continue
+ for field in schema.columns:
+ _append_to_results(field, False)
+
+ return list(results.values())
+
+
def _construct_table_input(
table_name: str,
metadata_location: str,
properties: Properties,
+ metadata: TableMetadata,
glue_table: Optional[TableTypeDef] = None,
prev_metadata_location: Optional[str] = None,
) -> TableInputTypeDef:
@@ -95,6 +197,10 @@ def _construct_table_input(
"Name": table_name,
"TableType": EXTERNAL_TABLE,
"Parameters": _construct_parameters(metadata_location, glue_table, prev_metadata_location),
+ "StorageDescriptor": {
+ "Columns": _to_columns(metadata),
+ "Location": metadata.location,
+ },
}
if "Description" in properties:
@@ -258,7 +364,7 @@ class GlueCatalog(Catalog):
io = load_file_io(properties=self.properties, location=metadata_location)
self._write_metadata(metadata, io, metadata_location)
- table_input = _construct_table_input(table_name, metadata_location, properties)
+ table_input = _construct_table_input(table_name, metadata_location, properties, metadata)
database_name, table_name = self.identifier_to_database_and_table(identifier)
self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input)
@@ -322,6 +428,7 @@ class GlueCatalog(Catalog):
table_name=table_name,
metadata_location=new_metadata_location,
properties=current_table.properties,
+ metadata=updated_metadata,
glue_table=current_glue_table,
prev_metadata_location=current_table.metadata_location,
)
diff --git a/tests/catalog/integration_test_glue.py b/tests/catalog/integration_test_glue.py
index 24401ca..a56e4c6 100644
--- a/tests/catalog/integration_test_glue.py
+++ b/tests/catalog/integration_test_glue.py
@@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Generator, List
+import time
+from typing import Any, Dict, Generator, List
+from uuid import uuid4
import boto3
+import pyarrow as pa
import pytest
from botocore.exceptions import ClientError
@@ -30,6 +33,7 @@ from pyiceberg.exceptions import (
NoSuchTableError,
TableAlreadyExistsError,
)
+from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType
from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
clean_up(test_catalog)
+class AthenaQueryHelper:
+ _athena_client: boto3.client
+ _s3_resource: boto3.resource
+ _output_bucket: str
+ _output_path: str
+
+ def __init__(self) -> None:
+ self._s3_resource = boto3.resource("s3")
+ self._athena_client = boto3.client("athena")
+ self._output_bucket = get_bucket_name()
+ self._output_path = f"athena_results_{uuid4()}"
+
+ def get_query_results(self, query: str) -> List[Dict[str, Any]]:
+ query_execution_id = self._athena_client.start_query_execution(
+ QueryString=query, ResultConfiguration={"OutputLocation": f"s3://{self._output_bucket}/{self._output_path}"}
+ )["QueryExecutionId"]
+
+ while True:
+ result = self._athena_client.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]
+ query_status = result["State"]
+ assert query_status not in [
+ "FAILED",
+ "CANCELLED",
+ ], f"""
+ Athena query with the string failed or was cancelled:
+ Query: {query}
+ Status: {query_status}
+ Reason: {result["StateChangeReason"]}"""
+
+ if query_status not in ["QUEUED", "RUNNING"]:
+ break
+ time.sleep(0.5)
+
+ # No pagination for now, assume that we are not doing large queries
+ return self._athena_client.get_query_results(QueryExecutionId=query_execution_id)["ResultSet"]["Rows"]
+
+ def clean_up(self) -> None:
+ bucket = self._s3_resource.Bucket(self._output_bucket)
+ for obj in bucket.objects.filter(Prefix=f"{self._output_path}/"):
+ self._s3_resource.Object(bucket.name, obj.key).delete()
+
+
+@pytest.fixture(name="athena", scope="module")
+def fixture_athena_helper() -> Generator[AthenaQueryHelper, None, None]:
+ query_helper = AthenaQueryHelper()
+ yield query_helper
+ query_helper.clean_up()
+
+
def test_create_table(
- test_catalog: Catalog, s3: boto3.client, table_schema_nested: Schema, table_name: str, database_name: str
+ test_catalog: Catalog,
+ s3: boto3.client,
+ table_schema_nested: Schema,
+ table_name: str,
+ database_name: str,
+ athena: AthenaQueryHelper,
) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
@@ -64,6 +122,48 @@ def test_create_table(
s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ table.append(
+ pa.Table.from_pylist(
+ [
+ {
+ "foo": "foo_val",
+ "bar": 1,
+ "baz": False,
+ "qux": ["x", "y"],
+ "quux": {"key": {"subkey": 2}},
+ "location": [{"latitude": 1.1}],
+ "person": {"name": "some_name", "age": 23},
+ }
+ ],
+ schema=schema_to_pyarrow(table.schema()),
+ ),
+ )
+
+ assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+ {
+ "Data": [
+ {"VarCharValue": "foo"},
+ {"VarCharValue": "bar"},
+ {"VarCharValue": "baz"},
+ {"VarCharValue": "qux"},
+ {"VarCharValue": "quux"},
+ {"VarCharValue": "location"},
+ {"VarCharValue": "person"},
+ ]
+ },
+ {
+ "Data": [
+ {"VarCharValue": "foo_val"},
+ {"VarCharValue": "1"},
+ {"VarCharValue": "false"},
+ {"VarCharValue": "[x, y]"},
+ {"VarCharValue": "{key={subkey=2}}"},
+ {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+ {"VarCharValue": "{name=some_name, age=23}"},
+ ]
+ },
+ ]
+
def test_create_table_with_invalid_location(table_schema_nested: Schema, table_name: str, database_name: str) -> None:
identifier = (database_name, table_name)
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
def test_commit_table_update_schema(
- test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
+ test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str, athena: AthenaQueryHelper
) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(namespace=database_name)
@@ -279,6 +379,20 @@ def test_commit_table_update_schema(
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
assert original_table_metadata.current_schema_id == 0
+ assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+ {
+ "Data": [
+ {"VarCharValue": "foo"},
+ {"VarCharValue": "bar"},
+ {"VarCharValue": "baz"},
+ {"VarCharValue": "qux"},
+ {"VarCharValue": "quux"},
+ {"VarCharValue": "location"},
+ {"VarCharValue": "person"},
+ ]
+ }
+ ]
+
transaction = table.transaction()
update = transaction.update_schema()
update.add_column(path="b", field_type=IntegerType())
@@ -295,6 +409,48 @@ def test_commit_table_update_schema(
assert new_schema == update._apply()
assert new_schema.find_field("b").field_type == IntegerType()
+ table.append(
+ pa.Table.from_pylist(
+ [
+ {
+ "foo": "foo_val",
+ "bar": 1,
+ "location": [{"latitude": 1.1}],
+ "person": {"name": "some_name", "age": 23},
+ "b": 2,
+ }
+ ],
+ schema=schema_to_pyarrow(new_schema),
+ ),
+ )
+
+ assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+ {
+ "Data": [
+ {"VarCharValue": "foo"},
+ {"VarCharValue": "bar"},
+ {"VarCharValue": "baz"},
+ {"VarCharValue": "qux"},
+ {"VarCharValue": "quux"},
+ {"VarCharValue": "location"},
+ {"VarCharValue": "person"},
+ {"VarCharValue": "b"},
+ ]
+ },
+ {
+ "Data": [
+ {"VarCharValue": "foo_val"},
+ {"VarCharValue": "1"},
+ {},
+ {"VarCharValue": "[]"},
+ {"VarCharValue": "{}"},
+ {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+ {"VarCharValue": "{name=some_name, age=23}"},
+ {"VarCharValue": "2"},
+ ]
+ },
+ ]
+
def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
identifier = (database_name, table_name)
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index bf6d117..b1f1371 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -38,7 +38,12 @@ from tests.conftest import BUCKET_NAME, TABLE_METADATA_LOCATION_REGEX
@mock_glue
def test_create_table_with_database_location(
- _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+ _glue: boto3.client,
+ _bucket_initialize: None,
+ moto_endpoint_url: str,
+ table_schema_nested: Schema,
+ database_name: str,
+ table_name: str,
) -> None:
catalog_name = "glue"
identifier = (database_name, table_name)
@@ -49,6 +54,22 @@ def test_create_table_with_database_location(
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+ # Ensure schema is also pushed to Glue
+ table_info = _glue.get_table(
+ DatabaseName=database_name,
+ Name=table_name,
+ )
+ storage_descriptor = table_info["Table"]["StorageDescriptor"]
+ columns = storage_descriptor["Columns"]
+ assert len(columns) == len(table_schema_nested.fields)
+ assert columns[0] == {
+ "Name": "foo",
+ "Type": "string",
+ "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+ }
+
+ assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
@mock_glue
def test_create_table_with_default_warehouse(
@@ -524,7 +545,12 @@ def test_passing_profile_name() -> None:
@mock_glue
def test_commit_table_update_schema(
- _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+ _glue: boto3.client,
+ _bucket_initialize: None,
+ moto_endpoint_url: str,
+ table_schema_nested: Schema,
+ database_name: str,
+ table_name: str,
) -> None:
catalog_name = "glue"
identifier = (database_name, table_name)
@@ -554,6 +580,21 @@ def test_commit_table_update_schema(
assert new_schema == update._apply()
assert new_schema.find_field("b").field_type == IntegerType()
+ # Ensure schema is also pushed to Glue
+ table_info = _glue.get_table(
+ DatabaseName=database_name,
+ Name=table_name,
+ )
+ storage_descriptor = table_info["Table"]["StorageDescriptor"]
+ columns = storage_descriptor["Columns"]
+ assert len(columns) == len(table_schema_nested.fields) + 1
+ assert columns[-1] == {
+ "Name": "b",
+ "Type": "int",
+ "Parameters": {"iceberg.field.id": "18", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+ }
+ assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
@mock_glue
def test_commit_table_properties(