This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 4391919  Set Glue Table Information when creating/updating tables (#288)
4391919 is described below

commit 43919199ea961db25aa855b6bbe5b370d252fe21
Author: Michael Marino <[email protected]>
AuthorDate: Thu Jan 25 02:44:34 2024 +0100

    Set Glue Table Information when creating/updating tables (#288)
    
    * Set Glue Table Information when creating/updating tables
    * Add integration tests for glue/Athena
---
 pyiceberg/catalog/glue.py              | 113 ++++++++++++++++++++++-
 tests/catalog/integration_test_glue.py | 162 ++++++++++++++++++++++++++++++++-
 tests/catalog/test_glue.py             |  45 ++++++++-
 3 files changed, 312 insertions(+), 8 deletions(-)

diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index 6cf9462..bccbfa4 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -18,6 +18,7 @@
 
 from typing import (
     Any,
+    Dict,
     List,
     Optional,
     Set,
@@ -28,6 +29,7 @@ from typing import (
 import boto3
 from mypy_boto3_glue.client import GlueClient
 from mypy_boto3_glue.type_defs import (
+    ColumnTypeDef,
     DatabaseInputTypeDef,
     DatabaseTypeDef,
     StorageDescriptorTypeDef,
@@ -59,12 +61,32 @@ from pyiceberg.exceptions import (
 )
 from pyiceberg.io import load_file_io
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
-from pyiceberg.schema import Schema
+from pyiceberg.schema import Schema, SchemaVisitor, visit
 from pyiceberg.serializers import FromInputFile
 from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
-from pyiceberg.table.metadata import new_table_metadata
+from pyiceberg.table.metadata import TableMetadata, new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    ListType,
+    LongType,
+    MapType,
+    NestedField,
+    PrimitiveType,
+    StringType,
+    StructType,
+    TimestampType,
+    TimeType,
+    UUIDType,
+)
 
 # If Glue should skip archiving an old table version when creating a new version in a commit. By
 # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
@@ -73,6 +95,10 @@ from pyiceberg.typedef import EMPTY_DICT
 GLUE_SKIP_ARCHIVE = "glue.skip-archive"
 GLUE_SKIP_ARCHIVE_DEFAULT = True
 
+ICEBERG_FIELD_ID = "iceberg.field.id"
+ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional"
+ICEBERG_FIELD_CURRENT = "iceberg.field.current"
+
 
 def _construct_parameters(
     metadata_location: str, glue_table: Optional[TableTypeDef] = None, prev_metadata_location: Optional[str] = None
@@ -84,10 +110,86 @@ def _construct_parameters(
     return new_parameters
 
 
+GLUE_PRIMITIVE_TYPES = {
+    BooleanType: "boolean",
+    IntegerType: "int",
+    LongType: "bigint",
+    FloatType: "float",
+    DoubleType: "double",
+    DateType: "date",
+    TimeType: "string",
+    StringType: "string",
+    UUIDType: "string",
+    TimestampType: "timestamp",
+    FixedType: "binary",
+    BinaryType: "binary",
+}
+
+
+class _IcebergSchemaToGlueType(SchemaVisitor[str]):
+    def schema(self, schema: Schema, struct_result: str) -> str:
+        return struct_result
+
+    def struct(self, struct: StructType, field_results: List[str]) -> str:
+        return f"struct<{','.join(field_results)}>"
+
+    def field(self, field: NestedField, field_result: str) -> str:
+        return f"{field.name}:{field_result}"
+
+    def list(self, list_type: ListType, element_result: str) -> str:
+        return f"array<{element_result}>"
+
+    def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
+        return f"map<{key_result},{value_result}>"
+
+    def primitive(self, primitive: PrimitiveType) -> str:
+        if isinstance(primitive, DecimalType):
+            return f"decimal({primitive.precision},{primitive.scale})"
+        if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
+            raise ValueError(f"Unknown primitive type: {primitive}")
+        return GLUE_PRIMITIVE_TYPES[primitive_type]
+
+
+def _to_columns(metadata: TableMetadata) -> List[ColumnTypeDef]:
+    results: Dict[str, ColumnTypeDef] = {}
+
+    def _append_to_results(field: NestedField, is_current: bool) -> None:
+        if field.name in results:
+            return
+
+        results[field.name] = cast(
+            ColumnTypeDef,
+            {
+                "Name": field.name,
+                "Type": visit(field.field_type, _IcebergSchemaToGlueType()),
+                "Parameters": {
+                    ICEBERG_FIELD_ID: str(field.field_id),
+                    ICEBERG_FIELD_OPTIONAL: str(field.optional).lower(),
+                    ICEBERG_FIELD_CURRENT: str(is_current).lower(),
+                },
+            },
+        )
+        if field.doc:
+            results[field.name]["Comment"] = field.doc
+
+    if current_schema := metadata.schema_by_id(metadata.current_schema_id):
+        for field in current_schema.columns:
+            _append_to_results(field, True)
+
+    for schema in metadata.schemas:
+        if schema.schema_id == metadata.current_schema_id:
+            continue
+        for field in schema.columns:
+            _append_to_results(field, False)
+
+    return list(results.values())
+
+
 def _construct_table_input(
     table_name: str,
     metadata_location: str,
     properties: Properties,
+    metadata: TableMetadata,
     glue_table: Optional[TableTypeDef] = None,
     prev_metadata_location: Optional[str] = None,
 ) -> TableInputTypeDef:
@@ -95,6 +197,10 @@ def _construct_table_input(
         "Name": table_name,
         "TableType": EXTERNAL_TABLE,
         "Parameters": _construct_parameters(metadata_location, glue_table, prev_metadata_location),
+        "StorageDescriptor": {
+            "Columns": _to_columns(metadata),
+            "Location": metadata.location,
+        },
     }
 
     if "Description" in properties:
@@ -258,7 +364,7 @@ class GlueCatalog(Catalog):
         io = load_file_io(properties=self.properties, location=metadata_location)
         self._write_metadata(metadata, io, metadata_location)
 
-        table_input = _construct_table_input(table_name, metadata_location, properties)
+        table_input = _construct_table_input(table_name, metadata_location, properties, metadata)
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input)
 
@@ -322,6 +428,7 @@ class GlueCatalog(Catalog):
             table_name=table_name,
             metadata_location=new_metadata_location,
             properties=current_table.properties,
+            metadata=updated_metadata,
             glue_table=current_glue_table,
             prev_metadata_location=current_table.metadata_location,
         )
diff --git a/tests/catalog/integration_test_glue.py b/tests/catalog/integration_test_glue.py
index 24401ca..a56e4c6 100644
--- a/tests/catalog/integration_test_glue.py
+++ b/tests/catalog/integration_test_glue.py
@@ -15,9 +15,12 @@
 #  specific language governing permissions and limitations
 #  under the License.
 
-from typing import Generator, List
+import time
+from typing import Any, Dict, Generator, List
+from uuid import uuid4
 
 import boto3
+import pyarrow as pa
 import pytest
 from botocore.exceptions import ClientError
 
@@ -30,6 +33,7 @@ from pyiceberg.exceptions import (
     NoSuchTableError,
     TableAlreadyExistsError,
 )
+from pyiceberg.io.pyarrow import schema_to_pyarrow
 from pyiceberg.schema import Schema
 from pyiceberg.types import IntegerType
 from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
     clean_up(test_catalog)
 
 
+class AthenaQueryHelper:
+    _athena_client: boto3.client
+    _s3_resource: boto3.resource
+    _output_bucket: str
+    _output_path: str
+
+    def __init__(self) -> None:
+        self._s3_resource = boto3.resource("s3")
+        self._athena_client = boto3.client("athena")
+        self._output_bucket = get_bucket_name()
+        self._output_path = f"athena_results_{uuid4()}"
+
+    def get_query_results(self, query: str) -> List[Dict[str, Any]]:
+        query_execution_id = self._athena_client.start_query_execution(
+            QueryString=query, ResultConfiguration={"OutputLocation": f"s3://{self._output_bucket}/{self._output_path}"}
+        )["QueryExecutionId"]
+
+        while True:
+            result = self._athena_client.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]
+            query_status = result["State"]
+            assert query_status not in [
+                "FAILED",
+                "CANCELLED",
+            ], f"""
+   Athena query with the string failed or was cancelled:
+   Query: {query}
+   Status: {query_status}
+   Reason: {result["StateChangeReason"]}"""
+
+            if query_status not in ["QUEUED", "RUNNING"]:
+                break
+            time.sleep(0.5)
+
+        # No pagination for now, assume that we are not doing large queries
+        return self._athena_client.get_query_results(QueryExecutionId=query_execution_id)["ResultSet"]["Rows"]
+
+    def clean_up(self) -> None:
+        bucket = self._s3_resource.Bucket(self._output_bucket)
+        for obj in bucket.objects.filter(Prefix=f"{self._output_path}/"):
+            self._s3_resource.Object(bucket.name, obj.key).delete()
+
+
+@pytest.fixture(name="athena", scope="module")
+def fixture_athena_helper() -> Generator[AthenaQueryHelper, None, None]:
+    query_helper = AthenaQueryHelper()
+    yield query_helper
+    query_helper.clean_up()
+
+
 def test_create_table(
-    test_catalog: Catalog, s3: boto3.client, table_schema_nested: Schema, table_name: str, database_name: str
+    test_catalog: Catalog,
+    s3: boto3.client,
+    table_schema_nested: Schema,
+    table_name: str,
+    database_name: str,
+    athena: AthenaQueryHelper,
 ) -> None:
     identifier = (database_name, table_name)
     test_catalog.create_namespace(database_name)
@@ -64,6 +122,48 @@ def test_create_table(
     s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
+    table.append(
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "foo_val",
+                    "bar": 1,
+                    "baz": False,
+                    "qux": ["x", "y"],
+                    "quux": {"key": {"subkey": 2}},
+                    "location": [{"latitude": 1.1}],
+                    "person": {"name": "some_name", "age": 23},
+                }
+            ],
+            schema=schema_to_pyarrow(table.schema()),
+        ),
+    )
+
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+            ]
+        },
+        {
+            "Data": [
+                {"VarCharValue": "foo_val"},
+                {"VarCharValue": "1"},
+                {"VarCharValue": "false"},
+                {"VarCharValue": "[x, y]"},
+                {"VarCharValue": "{key={subkey=2}}"},
+                {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+                {"VarCharValue": "{name=some_name, age=23}"},
+            ]
+        },
+    ]
+
 
 def test_create_table_with_invalid_location(table_schema_nested: Schema, table_name: str, database_name: str) -> None:
     identifier = (database_name, table_name)
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
 
 
 def test_commit_table_update_schema(
-    test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
+    test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str, athena: AthenaQueryHelper
 ) -> None:
     identifier = (database_name, table_name)
     test_catalog.create_namespace(namespace=database_name)
@@ -279,6 +379,20 @@ def test_commit_table_update_schema(
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
     assert original_table_metadata.current_schema_id == 0
 
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+            ]
+        }
+    ]
+
     transaction = table.transaction()
     update = transaction.update_schema()
     update.add_column(path="b", field_type=IntegerType())
@@ -295,6 +409,48 @@ def test_commit_table_update_schema(
     assert new_schema == update._apply()
     assert new_schema.find_field("b").field_type == IntegerType()
 
+    table.append(
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "foo_val",
+                    "bar": 1,
+                    "location": [{"latitude": 1.1}],
+                    "person": {"name": "some_name", "age": 23},
+                    "b": 2,
+                }
+            ],
+            schema=schema_to_pyarrow(new_schema),
+        ),
+    )
+
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+                {"VarCharValue": "b"},
+            ]
+        },
+        {
+            "Data": [
+                {"VarCharValue": "foo_val"},
+                {"VarCharValue": "1"},
+                {},
+                {"VarCharValue": "[]"},
+                {"VarCharValue": "{}"},
+                {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+                {"VarCharValue": "{name=some_name, age=23}"},
+                {"VarCharValue": "2"},
+            ]
+        },
+    ]
+
 
 def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
     identifier = (database_name, table_name)
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index bf6d117..b1f1371 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -38,7 +38,12 @@ from tests.conftest import BUCKET_NAME, TABLE_METADATA_LOCATION_REGEX
 
 @mock_glue
 def test_create_table_with_database_location(
-    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+    _glue: boto3.client,
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    table_schema_nested: Schema,
+    database_name: str,
+    table_name: str,
 ) -> None:
     catalog_name = "glue"
     identifier = (database_name, table_name)
@@ -49,6 +54,22 @@ def test_create_table_with_database_location(
     assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
+    # Ensure schema is also pushed to Glue
+    table_info = _glue.get_table(
+        DatabaseName=database_name,
+        Name=table_name,
+    )
+    storage_descriptor = table_info["Table"]["StorageDescriptor"]
+    columns = storage_descriptor["Columns"]
+    assert len(columns) == len(table_schema_nested.fields)
+    assert columns[0] == {
+        "Name": "foo",
+        "Type": "string",
+        "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+    }
+
+    assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
 
 @mock_glue
 def test_create_table_with_default_warehouse(
@@ -524,7 +545,12 @@ def test_passing_profile_name() -> None:
 
 @mock_glue
 def test_commit_table_update_schema(
-    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+    _glue: boto3.client,
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    table_schema_nested: Schema,
+    database_name: str,
+    table_name: str,
 ) -> None:
     catalog_name = "glue"
     identifier = (database_name, table_name)
@@ -554,6 +580,21 @@ def test_commit_table_update_schema(
     assert new_schema == update._apply()
     assert new_schema.find_field("b").field_type == IntegerType()
 
+    # Ensure schema is also pushed to Glue
+    table_info = _glue.get_table(
+        DatabaseName=database_name,
+        Name=table_name,
+    )
+    storage_descriptor = table_info["Table"]["StorageDescriptor"]
+    columns = storage_descriptor["Columns"]
+    assert len(columns) == len(table_schema_nested.fields) + 1
+    assert columns[-1] == {
+        "Name": "b",
+        "Type": "int",
+        "Parameters": {"iceberg.field.id": "18", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+    }
+    assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
 
 @mock_glue
 def test_commit_table_properties(

Reply via email to