This is an automated email from the ASF dual-hosted git repository.

honahx pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg-python.git


The following commit(s) were added to refs/heads/main by this push:
     new d872245f Remove trailing slash from table location when creating a 
table (#702)
d872245f is described below

commit d872245f143b919b25fda90a7b7c7fb0729a402a
Author: Felix Scherz <[email protected]>
AuthorDate: Mon May 6 19:32:09 2024 +0200

    Remove trailing slash from table location when creating a table (#702)
---
 pyiceberg/catalog/__init__.py  |   2 +-
 pyiceberg/catalog/rest.py      |   2 +
 tests/catalog/test_base.py     |  14 ++++
 tests/catalog/test_dynamodb.py |  15 ++++
 tests/catalog/test_glue.py     |  16 ++++
 tests/catalog/test_hive.py     | 175 +++++++++++++++++++++++++++++++++++++++++
 tests/catalog/test_rest.py     |  25 ++++++
 tests/catalog/test_sql.py      |  22 ++++++
 8 files changed, 270 insertions(+), 1 deletion(-)

diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py
index 18d803fe..5bb9ec27 100644
--- a/pyiceberg/catalog/__init__.py
+++ b/pyiceberg/catalog/__init__.py
@@ -779,7 +779,7 @@ class MetastoreCatalog(Catalog, ABC):
     def _resolve_table_location(self, location: Optional[str], database_name: 
str, table_name: str) -> str:
         if not location:
             return self._get_default_warehouse_location(database_name, 
table_name)
-        return location
+        return location.rstrip("/")
 
     def _get_default_warehouse_location(self, database_name: str, table_name: 
str) -> str:
         database_properties = self.load_namespace_properties(database_name)
diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py
index 53e3f6a1..565d8091 100644
--- a/pyiceberg/catalog/rest.py
+++ b/pyiceberg/catalog/rest.py
@@ -519,6 +519,8 @@ class RestCatalog(Catalog):
         fresh_sort_order = assign_fresh_sort_order_ids(sort_order, 
iceberg_schema, fresh_schema)
 
         namespace_and_table = self._split_identifier_for_path(identifier)
+        if location:
+            location = location.rstrip("/")
         request = CreateTableRequest(
             name=namespace_and_table["table"],
             location=location,
diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py
index 7d5e0a97..06e9a8a3 100644
--- a/tests/catalog/test_base.py
+++ b/tests/catalog/test_base.py
@@ -105,6 +105,7 @@ class InMemoryCatalog(MetastoreCatalog):
 
             if not location:
                 location = f'{self._warehouse_location}/{"/".join(identifier)}'
+            location = location.rstrip("/")
 
             metadata_location = self._get_metadata_location(location=location)
             metadata = new_table_metadata(
@@ -353,6 +354,19 @@ def test_create_table_location_override(catalog: 
InMemoryCatalog) -> None:
     assert table.location() == new_location
 
 
+def test_create_table_removes_trailing_slash_from_location(catalog: 
InMemoryCatalog) -> None:
+    new_location = f"{catalog._warehouse_location}/new_location"
+    table = catalog.create_table(
+        identifier=TEST_TABLE_IDENTIFIER,
+        schema=TEST_TABLE_SCHEMA,
+        location=f"{new_location}/",
+        partition_spec=TEST_TABLE_PARTITION_SPEC,
+        properties=TEST_TABLE_PROPERTIES,
+    )
+    assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table
+    assert table.location() == new_location
+
+
 @pytest.mark.parametrize(
     "schema,expected",
     [
diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py
index 1c647cf8..f4b16d34 100644
--- a/tests/catalog/test_dynamodb.py
+++ b/tests/catalog/test_dynamodb.py
@@ -117,6 +117,21 @@ def test_create_table_with_given_location(
     assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
 
 
+@mock_aws
+def test_create_table_removes_trailing_slash_in_location(
+    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: 
Schema, database_name: str, table_name: str
+) -> None:
+    catalog_name = "test_ddb_catalog"
+    identifier = (database_name, table_name)
+    test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": 
moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name)
+    location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+    table = test_catalog.create_table(identifier=identifier, 
schema=table_schema_nested, location=f"{location}/")
+    assert table.identifier == (catalog_name,) + identifier
+    assert table.location() == location
+    assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+
+
 @mock_aws
 def test_create_table_with_no_location(
     _bucket_initialize: None, table_schema_nested: Schema, database_name: str, 
table_name: str
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index 5999b192..5b67b92c 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -137,6 +137,22 @@ def test_create_table_with_given_location(
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
 
+@mock_aws
+def test_create_table_removes_trailing_slash_in_location(
+    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: 
Schema, database_name: str, table_name: str
+) -> None:
+    catalog_name = "glue"
+    identifier = (database_name, table_name)
+    test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": 
moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name)
+    location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+    table = test_catalog.create_table(identifier=identifier, 
schema=table_schema_nested, location=f"{location}/")
+    assert table.identifier == (catalog_name,) + identifier
+    assert table.location() == location
+    assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
+    assert test_catalog._parse_metadata_version(table.metadata_location) == 0
+
+
 @mock_aws
 def test_create_table_with_pyarrow_schema(
     _bucket_initialize: None,
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index 70927ea1..af3a3801 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -365,6 +365,181 @@ def test_create_table(
     assert metadata.model_dump() == expected.model_dump()
 
 
+@pytest.mark.parametrize("hive2_compatible", [True, False])
+@patch("time.time", MagicMock(return_value=12345))
+def test_create_table_with_given_location_removes_trailing_slash(
+    table_schema_with_all_types: Schema, hive_database: HiveDatabase, 
hive_table: HiveTable, hive2_compatible: bool
+) -> None:
+    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+    if hive2_compatible:
+        catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, 
**{"hive.hive2-compatible": "true"})
+
+    location = f"{hive_database.locationUri}/table-given-location"
+
+    catalog._client = MagicMock()
+    catalog._client.__enter__().create_table.return_value = None
+    catalog._client.__enter__().get_table.return_value = hive_table
+    catalog._client.__enter__().get_database.return_value = hive_database
+    catalog.create_table(
+        ("default", "table"), schema=table_schema_with_all_types, 
properties={"owner": "javaberg"}, location=f"{location}/"
+    )
+
+    called_hive_table: HiveTable = 
catalog._client.__enter__().create_table.call_args[0][0]
+    # This one is generated within the function itself, so we need to extract
+    # it to construct the assert_called_with
+    metadata_location: str = called_hive_table.parameters["metadata_location"]
+    assert metadata_location.endswith(".metadata.json")
+    assert "/database/table-given-location/metadata/" in metadata_location
+    catalog._client.__enter__().create_table.assert_called_with(
+        HiveTable(
+            tableName="table",
+            dbName="default",
+            owner="javaberg",
+            createTime=12345,
+            lastAccessTime=12345,
+            retention=None,
+            sd=StorageDescriptor(
+                cols=[
+                    FieldSchema(name='boolean', type='boolean', comment=None),
+                    FieldSchema(name='integer', type='int', comment=None),
+                    FieldSchema(name='long', type='bigint', comment=None),
+                    FieldSchema(name='float', type='float', comment=None),
+                    FieldSchema(name='double', type='double', comment=None),
+                    FieldSchema(name='decimal', type='decimal(32,3)', 
comment=None),
+                    FieldSchema(name='date', type='date', comment=None),
+                    FieldSchema(name='time', type='string', comment=None),
+                    FieldSchema(name='timestamp', type='timestamp', 
comment=None),
+                    FieldSchema(
+                        name='timestamptz',
+                        type='timestamp' if hive2_compatible else 'timestamp 
with local time zone',
+                        comment=None,
+                    ),
+                    FieldSchema(name='string', type='string', comment=None),
+                    FieldSchema(name='uuid', type='string', comment=None),
+                    FieldSchema(name='fixed', type='binary', comment=None),
+                    FieldSchema(name='binary', type='binary', comment=None),
+                    FieldSchema(name='list', type='array<string>', 
comment=None),
+                    FieldSchema(name='map', type='map<string,int>', 
comment=None),
+                    FieldSchema(name='struct', 
type='struct<inner_string:string,inner_int:int>', comment=None),
+                ],
+                location=f"{hive_database.locationUri}/table-given-location",
+                inputFormat="org.apache.hadoop.mapred.FileInputFormat",
+                outputFormat="org.apache.hadoop.mapred.FileOutputFormat",
+                compressed=None,
+                numBuckets=None,
+                serdeInfo=SerDeInfo(
+                    name=None,
+                    
serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
+                    parameters=None,
+                    description=None,
+                    serializerClass=None,
+                    deserializerClass=None,
+                    serdeType=None,
+                ),
+                bucketCols=None,
+                sortCols=None,
+                parameters=None,
+                skewedInfo=None,
+                storedAsSubDirectories=None,
+            ),
+            partitionKeys=None,
+            parameters={"EXTERNAL": "TRUE", "table_type": "ICEBERG", 
"metadata_location": metadata_location},
+            viewOriginalText=None,
+            viewExpandedText=None,
+            tableType="EXTERNAL_TABLE",
+            privileges=None,
+            temporary=False,
+            rewriteEnabled=None,
+            creationMetadata=None,
+            catName=None,
+            ownerType=1,
+            writeId=-1,
+            isStatsCompliant=None,
+            colStats=None,
+            accessType=None,
+            requiredReadCapabilities=None,
+            requiredWriteCapabilities=None,
+            id=None,
+            fileMetadata=None,
+            dictionary=None,
+            txnId=None,
+        )
+    )
+
+    with open(metadata_location, encoding=UTF8) as f:
+        payload = f.read()
+
+    metadata = TableMetadataUtil.parse_raw(payload)
+
+    assert "database/table-given-location" in metadata.location
+
+    expected = TableMetadataV2(
+        location=metadata.location,
+        table_uuid=metadata.table_uuid,
+        last_updated_ms=metadata.last_updated_ms,
+        last_column_id=22,
+        schemas=[
+            Schema(
+                NestedField(field_id=1, name='boolean', 
field_type=BooleanType(), required=True),
+                NestedField(field_id=2, name='integer', 
field_type=IntegerType(), required=True),
+                NestedField(field_id=3, name='long', field_type=LongType(), 
required=True),
+                NestedField(field_id=4, name='float', field_type=FloatType(), 
required=True),
+                NestedField(field_id=5, name='double', 
field_type=DoubleType(), required=True),
+                NestedField(field_id=6, name='decimal', 
field_type=DecimalType(precision=32, scale=3), required=True),
+                NestedField(field_id=7, name='date', field_type=DateType(), 
required=True),
+                NestedField(field_id=8, name='time', field_type=TimeType(), 
required=True),
+                NestedField(field_id=9, name='timestamp', 
field_type=TimestampType(), required=True),
+                NestedField(field_id=10, name='timestamptz', 
field_type=TimestamptzType(), required=True),
+                NestedField(field_id=11, name='string', 
field_type=StringType(), required=True),
+                NestedField(field_id=12, name='uuid', field_type=UUIDType(), 
required=True),
+                NestedField(field_id=13, name='fixed', 
field_type=FixedType(length=12), required=True),
+                NestedField(field_id=14, name='binary', 
field_type=BinaryType(), required=True),
+                NestedField(
+                    field_id=15,
+                    name='list',
+                    field_type=ListType(type='list', element_id=18, 
element_type=StringType(), element_required=True),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=16,
+                    name='map',
+                    field_type=MapType(
+                        type='map', key_id=19, key_type=StringType(), 
value_id=20, value_type=IntegerType(), value_required=True
+                    ),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=17,
+                    name='struct',
+                    field_type=StructType(
+                        NestedField(field_id=21, name='inner_string', 
field_type=StringType(), required=False),
+                        NestedField(field_id=22, name='inner_int', 
field_type=IntegerType(), required=True),
+                    ),
+                    required=False,
+                ),
+                schema_id=0,
+                identifier_field_ids=[2],
+            )
+        ],
+        current_schema_id=0,
+        last_partition_id=999,
+        properties={"owner": "javaberg", 'write.parquet.compression-codec': 
'zstd'},
+        partition_specs=[PartitionSpec()],
+        default_spec_id=0,
+        current_snapshot_id=None,
+        snapshots=[],
+        snapshot_log=[],
+        metadata_log=[],
+        sort_orders=[SortOrder(order_id=0)],
+        default_sort_order_id=0,
+        refs={},
+        format_version=2,
+        last_sequence_number=0,
+    )
+
+    assert metadata.model_dump() == expected.model_dump()
+
+
 @patch("time.time", MagicMock(return_value=12345))
 def test_create_v1_table(table_schema_simple: Schema, hive_database: 
HiveDatabase, hive_table: HiveTable) -> None:
     catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py
index 15ddb01b..b8410d68 100644
--- a/tests/catalog/test_rest.py
+++ b/tests/catalog/test_rest.py
@@ -732,6 +732,31 @@ def test_create_table_200(
     assert actual == expected
 
 
+def test_create_table_with_given_location_removes_trailing_slash_200(
+    rest_mock: Mocker, table_schema_simple: Schema, 
example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
+) -> None:
+    rest_mock.post(
+        f"{TEST_URI}v1/namespaces/fokko/tables",
+        json=example_table_metadata_no_snapshot_v1_rest_json,
+        status_code=200,
+        request_headers=TEST_HEADERS,
+    )
+    catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN)
+    location = "s3://warehouse/database/table-custom-location"
+    catalog.create_table(
+        identifier=("fokko", "fokko2"),
+        schema=table_schema_simple,
+        location=f"{location}/",
+        partition_spec=PartitionSpec(
+            PartitionField(source_id=1, field_id=1000, 
transform=TruncateTransform(width=3), name="id"), spec_id=1
+        ),
+        sort_order=SortOrder(SortField(source_id=2, 
transform=IdentityTransform())),
+        properties={"owner": "fokko"},
+    )
+    assert rest_mock.last_request
+    assert rest_mock.last_request.json()["location"] == location
+
+
 def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> 
None:
     rest_mock.post(
         f"{TEST_URI}v1/namespaces/fokko/tables",
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 40a1566e..97965268 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -264,6 +264,28 @@ def test_create_table_with_default_warehouse_location(
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_create_table_with_given_location_removes_trailing_slash(
+    warehouse: Path, catalog: SqlCatalog, table_schema_nested: Schema, 
random_identifier: Identifier
+) -> None:
+    database_name, table_name = random_identifier
+    location = f"file://{warehouse}/{database_name}.db/{table_name}-given"
+    catalog.create_namespace(database_name)
+    catalog.create_table(random_identifier, table_schema_nested, 
location=f"{location}/")
+    table = catalog.load_table(random_identifier)
+    assert table.identifier == (catalog.name,) + random_identifier
+    assert table.metadata_location.startswith(f"file://{warehouse}")
+    assert os.path.exists(table.metadata_location[len("file://") :])
+    assert table.location() == location
+    catalog.drop_table(random_identifier)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [

Reply via email to